Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
77f2c42
1
Parent(s):
64eb375
去除chat_func文件,改用类控制模型
Browse files- ChuanhuChatbot.py +34 -71
- modules/__init__.py +0 -0
- modules/base_model.py +427 -0
- modules/chat_func.py +0 -497
- modules/config.py +1 -1
- modules/llama_func.py +38 -34
- modules/models.py +210 -0
- modules/openai_func.py +0 -65
- modules/presets.py +16 -28
- modules/utils.py +11 -96
ChuanhuChatbot.py
CHANGED
@@ -10,8 +10,7 @@ from modules.config import *
|
|
10 |
from modules.utils import *
|
11 |
from modules.presets import *
|
12 |
from modules.overwrites import *
|
13 |
-
from modules.
|
14 |
-
from modules.openai_func import get_usage
|
15 |
|
16 |
gr.Chatbot.postprocess = postprocess
|
17 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
@@ -21,12 +20,11 @@ with open("assets/custom.css", "r", encoding="utf-8") as f:
|
|
21 |
|
22 |
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
23 |
user_name = gr.State("")
|
24 |
-
history = gr.State([])
|
25 |
-
token_count = gr.State([])
|
26 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
27 |
user_api_key = gr.State(my_api_key)
|
28 |
user_question = gr.State("")
|
29 |
-
|
|
|
30 |
topic = gr.State("未命名对话历史记录")
|
31 |
|
32 |
with gr.Row():
|
@@ -64,7 +62,6 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
64 |
retryBtn = gr.Button("🔄 重新生成")
|
65 |
delFirstBtn = gr.Button("🗑️ 删除最旧对话")
|
66 |
delLastBtn = gr.Button("🗑️ 删除最新对话")
|
67 |
-
reduceTokenBtn = gr.Button("♻️ 总结对话")
|
68 |
|
69 |
with gr.Column():
|
70 |
with gr.Column(min_width=50, scale=1):
|
@@ -94,7 +91,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
94 |
multiselect=False,
|
95 |
value=REPLY_LANGUAGES[0],
|
96 |
)
|
97 |
-
index_files = gr.Files(label="上传索引文件", type="file"
|
98 |
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
|
99 |
# TODO: 公式ocr
|
100 |
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
|
@@ -104,7 +101,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
104 |
show_label=True,
|
105 |
placeholder=f"在这里输入System Prompt...",
|
106 |
label="System prompt",
|
107 |
-
value=
|
108 |
lines=10,
|
109 |
).style(container=False)
|
110 |
with gr.Accordion(label="加载Prompt模板", open=True):
|
@@ -202,23 +199,16 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
202 |
gr.Markdown(description)
|
203 |
gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
|
204 |
chatgpt_predict_args = dict(
|
205 |
-
fn=predict,
|
206 |
inputs=[
|
207 |
-
user_api_key,
|
208 |
-
systemPromptTxt,
|
209 |
-
history,
|
210 |
user_question,
|
211 |
chatbot,
|
212 |
-
token_count,
|
213 |
-
top_p,
|
214 |
-
temperature,
|
215 |
use_streaming_checkbox,
|
216 |
-
model_select_dropdown,
|
217 |
use_websearch_checkbox,
|
218 |
index_files,
|
219 |
language_select_dropdown,
|
220 |
],
|
221 |
-
outputs=[chatbot,
|
222 |
show_progress=True,
|
223 |
)
|
224 |
|
@@ -242,12 +232,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
242 |
)
|
243 |
|
244 |
get_usage_args = dict(
|
245 |
-
fn=
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
)
|
247 |
|
248 |
|
249 |
# Chatbot
|
250 |
-
cancelBtn.click(
|
251 |
|
252 |
user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
|
253 |
user_input.submit(**get_usage_args)
|
@@ -256,63 +252,39 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
256 |
submitBtn.click(**get_usage_args)
|
257 |
|
258 |
emptyBtn.click(
|
259 |
-
|
260 |
-
outputs=[chatbot,
|
261 |
show_progress=True,
|
262 |
)
|
263 |
emptyBtn.click(**reset_textbox_args)
|
264 |
|
265 |
retryBtn.click(**start_outputing_args).then(
|
266 |
-
retry,
|
267 |
[
|
268 |
-
user_api_key,
|
269 |
-
systemPromptTxt,
|
270 |
-
history,
|
271 |
chatbot,
|
272 |
-
token_count,
|
273 |
-
top_p,
|
274 |
-
temperature,
|
275 |
use_streaming_checkbox,
|
276 |
-
|
|
|
277 |
language_select_dropdown,
|
278 |
],
|
279 |
-
[chatbot,
|
280 |
show_progress=True,
|
281 |
).then(**end_outputing_args)
|
282 |
retryBtn.click(**get_usage_args)
|
283 |
|
284 |
delFirstBtn.click(
|
285 |
-
delete_first_conversation,
|
286 |
-
|
287 |
-
[
|
288 |
)
|
289 |
|
290 |
delLastBtn.click(
|
291 |
-
delete_last_conversation,
|
292 |
-
[chatbot
|
293 |
-
[chatbot,
|
294 |
-
show_progress=
|
295 |
)
|
296 |
|
297 |
-
reduceTokenBtn.click(
|
298 |
-
reduce_token_size,
|
299 |
-
[
|
300 |
-
user_api_key,
|
301 |
-
systemPromptTxt,
|
302 |
-
history,
|
303 |
-
chatbot,
|
304 |
-
token_count,
|
305 |
-
top_p,
|
306 |
-
temperature,
|
307 |
-
gr.State(sum(token_count.value[-4:])),
|
308 |
-
model_select_dropdown,
|
309 |
-
language_select_dropdown,
|
310 |
-
],
|
311 |
-
[chatbot, history, status_display, token_count],
|
312 |
-
show_progress=True,
|
313 |
-
)
|
314 |
-
reduceTokenBtn.click(**get_usage_args)
|
315 |
-
|
316 |
two_column.change(update_doc_config, [two_column], None)
|
317 |
|
318 |
# ChatGPT
|
@@ -336,30 +308,21 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
336 |
|
337 |
# S&L
|
338 |
saveHistoryBtn.click(
|
339 |
-
save_chat_history,
|
340 |
-
[saveFileName,
|
341 |
downloadFile,
|
342 |
show_progress=True,
|
343 |
)
|
344 |
saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
345 |
exportMarkdownBtn.click(
|
346 |
-
export_markdown,
|
347 |
-
[saveFileName,
|
348 |
downloadFile,
|
349 |
show_progress=True,
|
350 |
)
|
351 |
historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
352 |
-
historyFileSelectDropdown.change(
|
353 |
-
|
354 |
-
[historyFileSelectDropdown, systemPromptTxt, history, chatbot, user_name],
|
355 |
-
[saveFileName, systemPromptTxt, history, chatbot],
|
356 |
-
show_progress=True,
|
357 |
-
)
|
358 |
-
downloadFile.change(
|
359 |
-
load_chat_history,
|
360 |
-
[downloadFile, systemPromptTxt, history, chatbot, user_name],
|
361 |
-
[saveFileName, systemPromptTxt, history, chatbot],
|
362 |
-
)
|
363 |
|
364 |
# Advanced
|
365 |
default_btn.click(
|
|
|
10 |
from modules.utils import *
|
11 |
from modules.presets import *
|
12 |
from modules.overwrites import *
|
13 |
+
from modules.models import get_model
|
|
|
14 |
|
15 |
gr.Chatbot.postprocess = postprocess
|
16 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
|
|
20 |
|
21 |
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
22 |
user_name = gr.State("")
|
|
|
|
|
23 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
24 |
user_api_key = gr.State(my_api_key)
|
25 |
user_question = gr.State("")
|
26 |
+
current_model = gr.State(get_model(MODELS[0], my_api_key))
|
27 |
+
|
28 |
topic = gr.State("未命名对话历史记录")
|
29 |
|
30 |
with gr.Row():
|
|
|
62 |
retryBtn = gr.Button("🔄 重新生成")
|
63 |
delFirstBtn = gr.Button("🗑️ 删除最旧对话")
|
64 |
delLastBtn = gr.Button("🗑️ 删除最新对话")
|
|
|
65 |
|
66 |
with gr.Column():
|
67 |
with gr.Column(min_width=50, scale=1):
|
|
|
91 |
multiselect=False,
|
92 |
value=REPLY_LANGUAGES[0],
|
93 |
)
|
94 |
+
index_files = gr.Files(label="上传索引文件", type="file")
|
95 |
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
|
96 |
# TODO: 公式ocr
|
97 |
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
|
|
|
101 |
show_label=True,
|
102 |
placeholder=f"在这里输入System Prompt...",
|
103 |
label="System prompt",
|
104 |
+
value=INITIAL_SYSTEM_PROMPT,
|
105 |
lines=10,
|
106 |
).style(container=False)
|
107 |
with gr.Accordion(label="加载Prompt模板", open=True):
|
|
|
199 |
gr.Markdown(description)
|
200 |
gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
|
201 |
chatgpt_predict_args = dict(
|
202 |
+
fn=current_model.value.predict,
|
203 |
inputs=[
|
|
|
|
|
|
|
204 |
user_question,
|
205 |
chatbot,
|
|
|
|
|
|
|
206 |
use_streaming_checkbox,
|
|
|
207 |
use_websearch_checkbox,
|
208 |
index_files,
|
209 |
language_select_dropdown,
|
210 |
],
|
211 |
+
outputs=[chatbot, status_display],
|
212 |
show_progress=True,
|
213 |
)
|
214 |
|
|
|
232 |
)
|
233 |
|
234 |
get_usage_args = dict(
|
235 |
+
fn=current_model.value.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
|
236 |
+
)
|
237 |
+
|
238 |
+
load_history_from_file_args = dict(
|
239 |
+
fn=current_model.value.load_chat_history,
|
240 |
+
inputs=[historyFileSelectDropdown, chatbot, user_name],
|
241 |
+
outputs=[saveFileName, systemPromptTxt, chatbot]
|
242 |
)
|
243 |
|
244 |
|
245 |
# Chatbot
|
246 |
+
cancelBtn.click(current_model.value.interrupt, [], [])
|
247 |
|
248 |
user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
|
249 |
user_input.submit(**get_usage_args)
|
|
|
252 |
submitBtn.click(**get_usage_args)
|
253 |
|
254 |
emptyBtn.click(
|
255 |
+
current_model.value.reset,
|
256 |
+
outputs=[chatbot, status_display],
|
257 |
show_progress=True,
|
258 |
)
|
259 |
emptyBtn.click(**reset_textbox_args)
|
260 |
|
261 |
retryBtn.click(**start_outputing_args).then(
|
262 |
+
current_model.value.retry,
|
263 |
[
|
|
|
|
|
|
|
264 |
chatbot,
|
|
|
|
|
|
|
265 |
use_streaming_checkbox,
|
266 |
+
use_websearch_checkbox,
|
267 |
+
index_files,
|
268 |
language_select_dropdown,
|
269 |
],
|
270 |
+
[chatbot, status_display],
|
271 |
show_progress=True,
|
272 |
).then(**end_outputing_args)
|
273 |
retryBtn.click(**get_usage_args)
|
274 |
|
275 |
delFirstBtn.click(
|
276 |
+
current_model.value.delete_first_conversation,
|
277 |
+
None,
|
278 |
+
[status_display],
|
279 |
)
|
280 |
|
281 |
delLastBtn.click(
|
282 |
+
current_model.value.delete_last_conversation,
|
283 |
+
[chatbot],
|
284 |
+
[chatbot, status_display],
|
285 |
+
show_progress=False
|
286 |
)
|
287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
two_column.change(update_doc_config, [two_column], None)
|
289 |
|
290 |
# ChatGPT
|
|
|
308 |
|
309 |
# S&L
|
310 |
saveHistoryBtn.click(
|
311 |
+
current_model.value.save_chat_history,
|
312 |
+
[saveFileName, chatbot, user_name],
|
313 |
downloadFile,
|
314 |
show_progress=True,
|
315 |
)
|
316 |
saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
317 |
exportMarkdownBtn.click(
|
318 |
+
current_model.value.export_markdown,
|
319 |
+
[saveFileName, chatbot, user_name],
|
320 |
downloadFile,
|
321 |
show_progress=True,
|
322 |
)
|
323 |
historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
324 |
+
historyFileSelectDropdown.change(**load_history_from_file_args)
|
325 |
+
downloadFile.change(**load_history_from_file_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
# Advanced
|
328 |
default_btn.click(
|
modules/__init__.py
ADDED
File without changes
|
modules/base_model.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from typing import TYPE_CHECKING, List
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import json
|
6 |
+
import commentjson as cjson
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import requests
|
10 |
+
import urllib3
|
11 |
+
|
12 |
+
from tqdm import tqdm
|
13 |
+
import colorama
|
14 |
+
from duckduckgo_search import ddg
|
15 |
+
import asyncio
|
16 |
+
import aiohttp
|
17 |
+
from enum import Enum
|
18 |
+
|
19 |
+
from .presets import *
|
20 |
+
from .llama_func import *
|
21 |
+
from .utils import *
|
22 |
+
from . import shared
|
23 |
+
from .config import retrieve_proxy
|
24 |
+
|
25 |
+
|
26 |
+
class ModelType(Enum):
|
27 |
+
OpenAI = 0
|
28 |
+
ChatGLM = 1
|
29 |
+
LLaMA = 2
|
30 |
+
|
31 |
+
@classmethod
|
32 |
+
def get_type(cls, model_name: str):
|
33 |
+
model_type = None
|
34 |
+
if "gpt" in model_name.lower():
|
35 |
+
model_type = ModelType.OpenAI
|
36 |
+
elif "chatglm" in model_name.upper():
|
37 |
+
model_type = ModelType.ChatGLM
|
38 |
+
else:
|
39 |
+
model_type = ModelType.LLaMA
|
40 |
+
return model_type
|
41 |
+
|
42 |
+
|
43 |
+
class BaseLLMModel:
|
44 |
+
def __init__(self, model_name, temperature=1.0, top_p=1.0, max_generation_token=None, system_prompt="") -> None:
|
45 |
+
self.history = []
|
46 |
+
self.all_token_counts = []
|
47 |
+
self.model_name = model_name
|
48 |
+
self.model_type = ModelType.get_type(model_name)
|
49 |
+
self.api_key = None
|
50 |
+
self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
|
51 |
+
self.max_generation_token = max_generation_token if max_generation_token is not None else self.token_upper_limit
|
52 |
+
self.interrupted = False
|
53 |
+
self.temperature = temperature
|
54 |
+
self.top_p = top_p
|
55 |
+
self.system_prompt = system_prompt
|
56 |
+
|
57 |
+
|
58 |
+
def get_answer_stream_iter(self):
|
59 |
+
"""stream predict, need to be implemented
|
60 |
+
conversations are stored in self.history, with the most recent question, in OpenAI format
|
61 |
+
should return a generator, each time give the next word (str) in the answer
|
62 |
+
"""
|
63 |
+
pass
|
64 |
+
|
65 |
+
def get_answer_at_once(self):
|
66 |
+
"""predict at once, need to be implemented
|
67 |
+
conversations are stored in self.history, with the most recent question, in OpenAI format
|
68 |
+
Should return:
|
69 |
+
the answer (str)
|
70 |
+
total token count (int)
|
71 |
+
"""
|
72 |
+
pass
|
73 |
+
|
74 |
+
def billing_info(self):
|
75 |
+
"""get billing infomation, inplement if needed"""
|
76 |
+
return billing_not_applicable_msg
|
77 |
+
|
78 |
+
|
79 |
+
def count_token(self, user_input):
|
80 |
+
"""get token count from input, implement if needed
|
81 |
+
"""
|
82 |
+
return 0
|
83 |
+
|
84 |
+
def stream_next_chatbot(
|
85 |
+
self, inputs, chatbot, fake_input=None, display_append=""
|
86 |
+
):
|
87 |
+
def get_return_value():
|
88 |
+
return chatbot, status_text
|
89 |
+
|
90 |
+
status_text = "开始实时传输回答……"
|
91 |
+
if fake_input:
|
92 |
+
chatbot.append((fake_input, ""))
|
93 |
+
else:
|
94 |
+
chatbot.append((inputs, ""))
|
95 |
+
|
96 |
+
user_token_count = self.count_token(inputs)
|
97 |
+
self.all_token_counts.append(user_token_count)
|
98 |
+
logging.debug(f"输入token计数: {user_token_count}")
|
99 |
+
|
100 |
+
stream_iter = self.get_answer_stream_iter()
|
101 |
+
|
102 |
+
for partial_text in stream_iter:
|
103 |
+
self.history[-1] = construct_assistant(partial_text)
|
104 |
+
chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
|
105 |
+
self.all_token_counts[-1] += 1
|
106 |
+
status_text = self.token_message()
|
107 |
+
yield get_return_value()
|
108 |
+
|
109 |
+
def next_chatbot_at_once(
|
110 |
+
self, inputs, chatbot, fake_input=None, display_append=""
|
111 |
+
):
|
112 |
+
if fake_input:
|
113 |
+
chatbot.append((fake_input, ""))
|
114 |
+
else:
|
115 |
+
chatbot.append((inputs, ""))
|
116 |
+
if fake_input is not None:
|
117 |
+
user_token_count = self.count_token(fake_input)
|
118 |
+
else:
|
119 |
+
user_token_count = self.count_token(inputs)
|
120 |
+
self.all_token_counts.append(user_token_count)
|
121 |
+
ai_reply, total_token_count = self.get_answer_at_once()
|
122 |
+
if fake_input is not None:
|
123 |
+
self.history[-2] = construct_user(fake_input)
|
124 |
+
self.history[-1] = construct_assistant(ai_reply)
|
125 |
+
chatbot[-1] = (chatbot[-1][0], ai_reply+display_append)
|
126 |
+
if fake_input is not None:
|
127 |
+
self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
|
128 |
+
else:
|
129 |
+
self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
|
130 |
+
status_text = self.token_message()
|
131 |
+
return chatbot, status_text
|
132 |
+
|
133 |
+
def predict(
|
134 |
+
self,
|
135 |
+
inputs,
|
136 |
+
chatbot,
|
137 |
+
stream=False,
|
138 |
+
use_websearch=False,
|
139 |
+
files=None,
|
140 |
+
reply_language="中文",
|
141 |
+
should_check_token_count=True,
|
142 |
+
): # repetition_penalty, top_k
|
143 |
+
from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
|
144 |
+
from llama_index.indices.query.schema import QueryBundle
|
145 |
+
from langchain.llms import OpenAIChat
|
146 |
+
|
147 |
+
logging.info(
|
148 |
+
"输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
|
149 |
+
)
|
150 |
+
if should_check_token_count:
|
151 |
+
yield chatbot + [(inputs, "")], "开始生成回答……"
|
152 |
+
if reply_language == "跟随问题语言(不稳定)":
|
153 |
+
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
154 |
+
old_inputs = None
|
155 |
+
display_reference = []
|
156 |
+
limited_context = False
|
157 |
+
if files and self.api_key:
|
158 |
+
limited_context = True
|
159 |
+
old_inputs = inputs
|
160 |
+
msg = "加载索引中……(这可能需要几分钟)"
|
161 |
+
logging.info(msg)
|
162 |
+
yield chatbot + [(inputs, "")], msg
|
163 |
+
index = construct_index(self.api_key, file_src=files)
|
164 |
+
msg = "索引构建完成,获取回答中……"
|
165 |
+
logging.info(msg)
|
166 |
+
yield chatbot + [(inputs, "")], msg
|
167 |
+
with retrieve_proxy():
|
168 |
+
llm_predictor = LLMPredictor(
|
169 |
+
llm=OpenAIChat(temperature=0, model_name=self.model_name)
|
170 |
+
)
|
171 |
+
prompt_helper = PromptHelper(
|
172 |
+
max_input_size=4096,
|
173 |
+
num_output=5,
|
174 |
+
max_chunk_overlap=20,
|
175 |
+
chunk_size_limit=600,
|
176 |
+
)
|
177 |
+
from llama_index import ServiceContext
|
178 |
+
|
179 |
+
service_context = ServiceContext.from_defaults(
|
180 |
+
llm_predictor=llm_predictor, prompt_helper=prompt_helper
|
181 |
+
)
|
182 |
+
query_object = GPTVectorStoreIndexQuery(
|
183 |
+
index.index_struct,
|
184 |
+
service_context=service_context,
|
185 |
+
similarity_top_k=5,
|
186 |
+
vector_store=index._vector_store,
|
187 |
+
docstore=index._docstore,
|
188 |
+
)
|
189 |
+
query_bundle = QueryBundle(inputs)
|
190 |
+
nodes = query_object.retrieve(query_bundle)
|
191 |
+
reference_results = [n.node.text for n in nodes]
|
192 |
+
reference_results = add_source_numbers(reference_results, use_source=False)
|
193 |
+
display_reference = add_details(reference_results)
|
194 |
+
display_reference = "\n\n" + "".join(display_reference)
|
195 |
+
inputs = (
|
196 |
+
replace_today(PROMPT_TEMPLATE)
|
197 |
+
.replace("{query_str}", inputs)
|
198 |
+
.replace("{context_str}", "\n\n".join(reference_results))
|
199 |
+
.replace("{reply_language}", reply_language)
|
200 |
+
)
|
201 |
+
elif use_websearch:
|
202 |
+
limited_context = True
|
203 |
+
search_results = ddg(inputs, max_results=5)
|
204 |
+
old_inputs = inputs
|
205 |
+
reference_results = []
|
206 |
+
for idx, result in enumerate(search_results):
|
207 |
+
logging.debug(f"搜索结果{idx + 1}:{result}")
|
208 |
+
domain_name = urllib3.util.parse_url(result["href"]).host
|
209 |
+
reference_results.append([result["body"], result["href"]])
|
210 |
+
display_reference.append(
|
211 |
+
f"{idx+1}. [{domain_name}]({result['href']})\n"
|
212 |
+
)
|
213 |
+
reference_results = add_source_numbers(reference_results)
|
214 |
+
display_reference = "\n\n" + "".join(display_reference)
|
215 |
+
inputs = (
|
216 |
+
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
|
217 |
+
.replace("{query}", inputs)
|
218 |
+
.replace("{web_results}", "\n\n".join(reference_results))
|
219 |
+
.replace("{reply_language}", reply_language)
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
display_reference = ""
|
223 |
+
|
224 |
+
if len(self.api_key) == 0 and not shared.state.multi_api_key:
|
225 |
+
status_text = standard_error_msg + no_apikey_msg
|
226 |
+
logging.info(status_text)
|
227 |
+
chatbot.append((inputs, ""))
|
228 |
+
if len(self.history) == 0:
|
229 |
+
self.history.append(construct_user(inputs))
|
230 |
+
self.history.append("")
|
231 |
+
self.all_token_counts.append(0)
|
232 |
+
else:
|
233 |
+
self.history[-2] = construct_user(inputs)
|
234 |
+
yield chatbot + [(inputs, "")], status_text
|
235 |
+
return
|
236 |
+
elif len(inputs.strip()) == 0:
|
237 |
+
status_text = standard_error_msg + no_input_msg
|
238 |
+
logging.info(status_text)
|
239 |
+
yield chatbot + [(inputs, "")], status_text
|
240 |
+
return
|
241 |
+
|
242 |
+
self.history.append(construct_user(inputs))
|
243 |
+
self.history.append(construct_assistant(""))
|
244 |
+
|
245 |
+
if stream:
|
246 |
+
logging.debug("使用流式传输")
|
247 |
+
iter = self.stream_next_chatbot(
|
248 |
+
inputs,
|
249 |
+
chatbot,
|
250 |
+
fake_input=old_inputs,
|
251 |
+
display_append=display_reference,
|
252 |
+
)
|
253 |
+
for chatbot, status_text in iter:
|
254 |
+
yield chatbot, status_text
|
255 |
+
if self.interrupted:
|
256 |
+
self.recover()
|
257 |
+
break
|
258 |
+
else:
|
259 |
+
logging.debug("不使用流式传输")
|
260 |
+
chatbot, status_text = self.next_chatbot_at_once(
|
261 |
+
inputs,
|
262 |
+
chatbot,
|
263 |
+
fake_input=old_inputs,
|
264 |
+
display_append=display_reference,
|
265 |
+
)
|
266 |
+
yield chatbot, status_text
|
267 |
+
|
268 |
+
if len(self.history) > 1 and self.history[-1]["content"] != inputs:
|
269 |
+
logging.info(
|
270 |
+
"回答为:"
|
271 |
+
+ colorama.Fore.BLUE
|
272 |
+
+ f"{self.history[-1]['content']}"
|
273 |
+
+ colorama.Style.RESET_ALL
|
274 |
+
)
|
275 |
+
|
276 |
+
if limited_context:
|
277 |
+
self.history = self.history[-4:]
|
278 |
+
self.all_token_counts = self.all_token_counts[-2:]
|
279 |
+
|
280 |
+
|
281 |
+
max_token = self.token_upper_limit - TOKEN_OFFSET
|
282 |
+
|
283 |
+
if sum(self.all_token_counts) > max_token and should_check_token_count:
|
284 |
+
count = 0
|
285 |
+
while sum(self.all_token_counts) > self.token_upper_limit * REDUCE_TOKEN_FACTOR and sum(self.all_token_counts) > 0:
|
286 |
+
count += 1
|
287 |
+
del self.all_token_counts[0]
|
288 |
+
del self.history[:2]
|
289 |
+
logging.info(status_text)
|
290 |
+
status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
|
291 |
+
yield chatbot, status_text
|
292 |
+
|
293 |
+
def retry(
|
294 |
+
self,
|
295 |
+
chatbot,
|
296 |
+
stream=False,
|
297 |
+
use_websearch=False,
|
298 |
+
files=None,
|
299 |
+
reply_language="中文",
|
300 |
+
):
|
301 |
+
logging.info("重试中……")
|
302 |
+
if len(self.history) == 0:
|
303 |
+
yield chatbot, f"{standard_error_msg}上下文是空的"
|
304 |
+
return
|
305 |
+
|
306 |
+
del self.history[-2:]
|
307 |
+
inputs = chatbot[-1][0]
|
308 |
+
self.all_token_counts.pop()
|
309 |
+
iter = self.predict(
|
310 |
+
inputs,
|
311 |
+
chatbot,
|
312 |
+
stream=stream,
|
313 |
+
use_websearch=use_websearch,
|
314 |
+
files=files,
|
315 |
+
reply_language=reply_language,
|
316 |
+
)
|
317 |
+
for x in iter:
|
318 |
+
yield x
|
319 |
+
logging.info("重试完毕")
|
320 |
+
|
321 |
+
# def reduce_token_size(self, chatbot):
|
322 |
+
# logging.info("开始减少token数量……")
|
323 |
+
# chatbot, status_text = self.next_chatbot_at_once(
|
324 |
+
# summarize_prompt,
|
325 |
+
# chatbot
|
326 |
+
# )
|
327 |
+
# max_token_count = self.token_upper_limit * REDUCE_TOKEN_FACTOR
|
328 |
+
# num_chat = find_n(self.all_token_counts, max_token_count)
|
329 |
+
# logging.info(f"previous_token_count: {self.all_token_counts}, keeping {num_chat} chats")
|
330 |
+
# chatbot = chatbot[:-1]
|
331 |
+
# self.history = self.history[-2*num_chat:] if num_chat > 0 else []
|
332 |
+
# self.all_token_counts = self.all_token_counts[-num_chat:] if num_chat > 0 else []
|
333 |
+
# msg = f"保留了最近{num_chat}轮对话"
|
334 |
+
# logging.info(msg)
|
335 |
+
# logging.info("减少token数量完毕")
|
336 |
+
# return chatbot, msg + "," + self.token_message(self.all_token_counts if len(self.all_token_counts) > 0 else [0])
|
337 |
+
|
338 |
+
def interrupt(self):
|
339 |
+
self.interrupted = True
|
340 |
+
|
341 |
+
def recover(self):
|
342 |
+
self.interrupted = False
|
343 |
+
|
344 |
+
def set_temprature(self, new_temprature):
|
345 |
+
self.temperature = new_temprature
|
346 |
+
|
347 |
+
def set_top_p(self, new_top_p):
|
348 |
+
self.top_p = new_top_p
|
349 |
+
|
350 |
+
def reset(self):
|
351 |
+
self.history = []
|
352 |
+
self.all_token_counts = []
|
353 |
+
self.interrupted = False
|
354 |
+
return [], self.token_message([0])
|
355 |
+
|
356 |
+
def delete_first_conversation(self):
|
357 |
+
if self.history:
|
358 |
+
del self.history[:2]
|
359 |
+
del self.all_token_counts[0]
|
360 |
+
return self.token_message()
|
361 |
+
|
362 |
+
def delete_last_conversation(self, chatbot):
|
363 |
+
if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
|
364 |
+
msg = "由于包含报错信息,只删除chatbot记录"
|
365 |
+
chatbot.pop()
|
366 |
+
return chatbot, self.history
|
367 |
+
if len(self.history) > 0:
|
368 |
+
self.history.pop()
|
369 |
+
self.history.pop()
|
370 |
+
if len(chatbot) > 0:
|
371 |
+
msg = "删除了一组chatbot对话"
|
372 |
+
chatbot.pop()
|
373 |
+
if len(self.all_token_counts) > 0:
|
374 |
+
msg = "删除了一组对话的token计数记录"
|
375 |
+
self.all_token_counts.pop()
|
376 |
+
msg = "删除了一组对话"
|
377 |
+
return chatbot, msg
|
378 |
+
|
379 |
+
def token_message(self, token_lst = None):
|
380 |
+
if token_lst is None:
|
381 |
+
token_lst = self.all_token_counts
|
382 |
+
token_sum = 0
|
383 |
+
for i in range(len(token_lst)):
|
384 |
+
token_sum += sum(token_lst[: i + 1])
|
385 |
+
return f"Token 计数: {sum(token_lst)},本次对话累计消耗了 {token_sum} tokens"
|
386 |
+
|
387 |
+
def save_chat_history(self, filename, chatbot, user_name):
|
388 |
+
if filename == "":
|
389 |
+
return
|
390 |
+
if not filename.endswith(".json"):
|
391 |
+
filename += ".json"
|
392 |
+
return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
|
393 |
+
|
394 |
+
def export_markdown(self, filename, chatbot, user_name):
|
395 |
+
if filename == "":
|
396 |
+
return
|
397 |
+
if not filename.endswith(".md"):
|
398 |
+
filename += ".md"
|
399 |
+
return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
|
400 |
+
|
401 |
+
def load_chat_history(self, filename, chatbot, user_name):
|
402 |
+
logging.info(f"{user_name} 加载对话历史中……")
|
403 |
+
if type(filename) != str:
|
404 |
+
filename = filename.name
|
405 |
+
try:
|
406 |
+
with open(os.path.join(HISTORY_DIR / user_name, filename), "r") as f:
|
407 |
+
json_s = json.load(f)
|
408 |
+
try:
|
409 |
+
if type(json_s["history"][0]) == str:
|
410 |
+
logging.info("历史记录格式为旧版,正在转换……")
|
411 |
+
new_history = []
|
412 |
+
for index, item in enumerate(json_s["history"]):
|
413 |
+
if index % 2 == 0:
|
414 |
+
new_history.append(construct_user(item))
|
415 |
+
else:
|
416 |
+
new_history.append(construct_assistant(item))
|
417 |
+
json_s["history"] = new_history
|
418 |
+
logging.info(new_history)
|
419 |
+
except:
|
420 |
+
# 没有对话历史
|
421 |
+
pass
|
422 |
+
logging.info(f"{user_name} 加载对话历史完毕")
|
423 |
+
self.history = json_s["history"]
|
424 |
+
return filename, json_s["system"], json_s["chatbot"]
|
425 |
+
except FileNotFoundError:
|
426 |
+
logging.info(f"{user_name} 没有找到对话历史文件,不执行任何操作")
|
427 |
+
return filename, self.system_prompt, chatbot
|
modules/chat_func.py
DELETED
@@ -1,497 +0,0 @@
|
|
1 |
-
# -*- coding:utf-8 -*-
|
2 |
-
from __future__ import annotations
|
3 |
-
from typing import TYPE_CHECKING, List
|
4 |
-
|
5 |
-
import logging
|
6 |
-
import json
|
7 |
-
import os
|
8 |
-
import requests
|
9 |
-
import urllib3
|
10 |
-
|
11 |
-
from tqdm import tqdm
|
12 |
-
import colorama
|
13 |
-
from duckduckgo_search import ddg
|
14 |
-
import asyncio
|
15 |
-
import aiohttp
|
16 |
-
|
17 |
-
|
18 |
-
from modules.presets import *
|
19 |
-
from modules.llama_func import *
|
20 |
-
from modules.utils import *
|
21 |
-
from . import shared
|
22 |
-
from modules.config import retrieve_proxy
|
23 |
-
|
24 |
-
# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
|
25 |
-
|
26 |
-
if TYPE_CHECKING:
|
27 |
-
from typing import TypedDict
|
28 |
-
|
29 |
-
class DataframeData(TypedDict):
|
30 |
-
headers: List[str]
|
31 |
-
data: List[List[str | int | bool]]
|
32 |
-
|
33 |
-
|
34 |
-
initial_prompt = "You are a helpful assistant."
|
35 |
-
HISTORY_DIR = "history"
|
36 |
-
TEMPLATES_DIR = "templates"
|
37 |
-
|
38 |
-
@shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
|
39 |
-
def get_response(
|
40 |
-
openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
|
41 |
-
):
|
42 |
-
headers = {
|
43 |
-
"Content-Type": "application/json",
|
44 |
-
"Authorization": f"Bearer {openai_api_key}",
|
45 |
-
}
|
46 |
-
|
47 |
-
history = [construct_system(system_prompt), *history]
|
48 |
-
|
49 |
-
payload = {
|
50 |
-
"model": selected_model,
|
51 |
-
"messages": history, # [{"role": "user", "content": f"{inputs}"}],
|
52 |
-
"temperature": temperature, # 1.0,
|
53 |
-
"top_p": top_p, # 1.0,
|
54 |
-
"n": 1,
|
55 |
-
"stream": stream,
|
56 |
-
"presence_penalty": 0,
|
57 |
-
"frequency_penalty": 0,
|
58 |
-
}
|
59 |
-
if stream:
|
60 |
-
timeout = timeout_streaming
|
61 |
-
else:
|
62 |
-
timeout = timeout_all
|
63 |
-
|
64 |
-
|
65 |
-
# 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
|
66 |
-
if shared.state.completion_url != COMPLETION_URL:
|
67 |
-
logging.info(f"使用自定义API URL: {shared.state.completion_url}")
|
68 |
-
|
69 |
-
with retrieve_proxy():
|
70 |
-
response = requests.post(
|
71 |
-
shared.state.completion_url,
|
72 |
-
headers=headers,
|
73 |
-
json=payload,
|
74 |
-
stream=True,
|
75 |
-
timeout=timeout,
|
76 |
-
)
|
77 |
-
|
78 |
-
return response
|
79 |
-
|
80 |
-
|
81 |
-
def stream_predict(
|
82 |
-
openai_api_key,
|
83 |
-
system_prompt,
|
84 |
-
history,
|
85 |
-
inputs,
|
86 |
-
chatbot,
|
87 |
-
all_token_counts,
|
88 |
-
top_p,
|
89 |
-
temperature,
|
90 |
-
selected_model,
|
91 |
-
fake_input=None,
|
92 |
-
display_append=""
|
93 |
-
):
|
94 |
-
def get_return_value():
|
95 |
-
return chatbot, history, status_text, all_token_counts
|
96 |
-
|
97 |
-
logging.info("实时回答模式")
|
98 |
-
partial_words = ""
|
99 |
-
counter = 0
|
100 |
-
status_text = "开始实时传输回答……"
|
101 |
-
history.append(construct_user(inputs))
|
102 |
-
history.append(construct_assistant(""))
|
103 |
-
if fake_input:
|
104 |
-
chatbot.append((fake_input, ""))
|
105 |
-
else:
|
106 |
-
chatbot.append((inputs, ""))
|
107 |
-
user_token_count = 0
|
108 |
-
if fake_input is not None:
|
109 |
-
input_token_count = count_token(construct_user(fake_input))
|
110 |
-
else:
|
111 |
-
input_token_count = count_token(construct_user(inputs))
|
112 |
-
if len(all_token_counts) == 0:
|
113 |
-
system_prompt_token_count = count_token(construct_system(system_prompt))
|
114 |
-
user_token_count = (
|
115 |
-
input_token_count + system_prompt_token_count
|
116 |
-
)
|
117 |
-
else:
|
118 |
-
user_token_count = input_token_count
|
119 |
-
all_token_counts.append(user_token_count)
|
120 |
-
logging.info(f"输入token计数: {user_token_count}")
|
121 |
-
yield get_return_value()
|
122 |
-
try:
|
123 |
-
response = get_response(
|
124 |
-
openai_api_key,
|
125 |
-
system_prompt,
|
126 |
-
history,
|
127 |
-
temperature,
|
128 |
-
top_p,
|
129 |
-
True,
|
130 |
-
selected_model,
|
131 |
-
)
|
132 |
-
except requests.exceptions.ConnectTimeout:
|
133 |
-
status_text = (
|
134 |
-
standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
135 |
-
)
|
136 |
-
yield get_return_value()
|
137 |
-
return
|
138 |
-
except requests.exceptions.ReadTimeout:
|
139 |
-
status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
|
140 |
-
yield get_return_value()
|
141 |
-
return
|
142 |
-
|
143 |
-
yield get_return_value()
|
144 |
-
error_json_str = ""
|
145 |
-
|
146 |
-
if fake_input is not None:
|
147 |
-
history[-2] = construct_user(fake_input)
|
148 |
-
for chunk in tqdm(response.iter_lines()):
|
149 |
-
if counter == 0:
|
150 |
-
counter += 1
|
151 |
-
continue
|
152 |
-
counter += 1
|
153 |
-
# check whether each line is non-empty
|
154 |
-
if chunk:
|
155 |
-
chunk = chunk.decode()
|
156 |
-
chunklength = len(chunk)
|
157 |
-
try:
|
158 |
-
chunk = json.loads(chunk[6:])
|
159 |
-
except json.JSONDecodeError:
|
160 |
-
logging.info(chunk)
|
161 |
-
error_json_str += chunk
|
162 |
-
status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
|
163 |
-
yield get_return_value()
|
164 |
-
continue
|
165 |
-
# decode each line as response data is in bytes
|
166 |
-
if chunklength > 6 and "delta" in chunk["choices"][0]:
|
167 |
-
finish_reason = chunk["choices"][0]["finish_reason"]
|
168 |
-
status_text = construct_token_message(all_token_counts)
|
169 |
-
if finish_reason == "stop":
|
170 |
-
yield get_return_value()
|
171 |
-
break
|
172 |
-
try:
|
173 |
-
partial_words = (
|
174 |
-
partial_words + chunk["choices"][0]["delta"]["content"]
|
175 |
-
)
|
176 |
-
except KeyError:
|
177 |
-
status_text = (
|
178 |
-
standard_error_msg
|
179 |
-
+ "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
|
180 |
-
+ str(sum(all_token_counts))
|
181 |
-
)
|
182 |
-
yield get_return_value()
|
183 |
-
break
|
184 |
-
history[-1] = construct_assistant(partial_words)
|
185 |
-
chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
|
186 |
-
all_token_counts[-1] += 1
|
187 |
-
yield get_return_value()
|
188 |
-
|
189 |
-
|
190 |
-
def predict_all(
|
191 |
-
openai_api_key,
|
192 |
-
system_prompt,
|
193 |
-
history,
|
194 |
-
inputs,
|
195 |
-
chatbot,
|
196 |
-
all_token_counts,
|
197 |
-
top_p,
|
198 |
-
temperature,
|
199 |
-
selected_model,
|
200 |
-
fake_input=None,
|
201 |
-
display_append=""
|
202 |
-
):
|
203 |
-
logging.info("一次性回答模式")
|
204 |
-
history.append(construct_user(inputs))
|
205 |
-
history.append(construct_assistant(""))
|
206 |
-
if fake_input:
|
207 |
-
chatbot.append((fake_input, ""))
|
208 |
-
else:
|
209 |
-
chatbot.append((inputs, ""))
|
210 |
-
if fake_input is not None:
|
211 |
-
all_token_counts.append(count_token(construct_user(fake_input)))
|
212 |
-
else:
|
213 |
-
all_token_counts.append(count_token(construct_user(inputs)))
|
214 |
-
try:
|
215 |
-
response = get_response(
|
216 |
-
openai_api_key,
|
217 |
-
system_prompt,
|
218 |
-
history,
|
219 |
-
temperature,
|
220 |
-
top_p,
|
221 |
-
False,
|
222 |
-
selected_model,
|
223 |
-
)
|
224 |
-
except requests.exceptions.ConnectTimeout:
|
225 |
-
status_text = (
|
226 |
-
standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
227 |
-
)
|
228 |
-
return chatbot, history, status_text, all_token_counts
|
229 |
-
except requests.exceptions.ProxyError:
|
230 |
-
status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
|
231 |
-
return chatbot, history, status_text, all_token_counts
|
232 |
-
except requests.exceptions.SSLError:
|
233 |
-
status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
|
234 |
-
return chatbot, history, status_text, all_token_counts
|
235 |
-
response = json.loads(response.text)
|
236 |
-
if fake_input is not None:
|
237 |
-
history[-2] = construct_user(fake_input)
|
238 |
-
try:
|
239 |
-
content = response["choices"][0]["message"]["content"]
|
240 |
-
history[-1] = construct_assistant(content)
|
241 |
-
chatbot[-1] = (chatbot[-1][0], content+display_append)
|
242 |
-
total_token_count = response["usage"]["total_tokens"]
|
243 |
-
if fake_input is not None:
|
244 |
-
all_token_counts[-1] += count_token(construct_assistant(content))
|
245 |
-
else:
|
246 |
-
all_token_counts[-1] = total_token_count - sum(all_token_counts)
|
247 |
-
status_text = construct_token_message(total_token_count)
|
248 |
-
return chatbot, history, status_text, all_token_counts
|
249 |
-
except KeyError:
|
250 |
-
status_text = standard_error_msg + str(response)
|
251 |
-
return chatbot, history, status_text, all_token_counts
|
252 |
-
|
253 |
-
|
254 |
-
def predict(
|
255 |
-
openai_api_key,
|
256 |
-
system_prompt,
|
257 |
-
history,
|
258 |
-
inputs,
|
259 |
-
chatbot,
|
260 |
-
all_token_counts,
|
261 |
-
top_p,
|
262 |
-
temperature,
|
263 |
-
stream=False,
|
264 |
-
selected_model=MODELS[0],
|
265 |
-
use_websearch=False,
|
266 |
-
files = None,
|
267 |
-
reply_language="中文",
|
268 |
-
should_check_token_count=True,
|
269 |
-
): # repetition_penalty, top_k
|
270 |
-
from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
|
271 |
-
from llama_index.indices.query.schema import QueryBundle
|
272 |
-
from langchain.llms import OpenAIChat
|
273 |
-
|
274 |
-
|
275 |
-
logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
|
276 |
-
if should_check_token_count:
|
277 |
-
yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
|
278 |
-
if reply_language == "跟随问题语言(不稳定)":
|
279 |
-
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
280 |
-
old_inputs = None
|
281 |
-
display_reference = []
|
282 |
-
limited_context = False
|
283 |
-
if files:
|
284 |
-
limited_context = True
|
285 |
-
old_inputs = inputs
|
286 |
-
msg = "加载索引中……(这可能需要几分钟)"
|
287 |
-
logging.info(msg)
|
288 |
-
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
289 |
-
index = construct_index(openai_api_key, file_src=files)
|
290 |
-
msg = "索引构建完成,获取回答中……"
|
291 |
-
logging.info(msg)
|
292 |
-
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
293 |
-
with retrieve_proxy():
|
294 |
-
llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
|
295 |
-
prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
|
296 |
-
from llama_index import ServiceContext
|
297 |
-
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
|
298 |
-
query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
|
299 |
-
query_bundle = QueryBundle(inputs)
|
300 |
-
nodes = query_object.retrieve(query_bundle)
|
301 |
-
reference_results = [n.node.text for n in nodes]
|
302 |
-
reference_results = add_source_numbers(reference_results, use_source=False)
|
303 |
-
display_reference = add_details(reference_results)
|
304 |
-
display_reference = "\n\n" + "".join(display_reference)
|
305 |
-
inputs = (
|
306 |
-
replace_today(PROMPT_TEMPLATE)
|
307 |
-
.replace("{query_str}", inputs)
|
308 |
-
.replace("{context_str}", "\n\n".join(reference_results))
|
309 |
-
.replace("{reply_language}", reply_language )
|
310 |
-
)
|
311 |
-
elif use_websearch:
|
312 |
-
limited_context = True
|
313 |
-
search_results = ddg(inputs, max_results=5)
|
314 |
-
old_inputs = inputs
|
315 |
-
reference_results = []
|
316 |
-
for idx, result in enumerate(search_results):
|
317 |
-
logging.info(f"搜索结果{idx + 1}:{result}")
|
318 |
-
domain_name = urllib3.util.parse_url(result["href"]).host
|
319 |
-
reference_results.append([result["body"], result["href"]])
|
320 |
-
display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
|
321 |
-
reference_results = add_source_numbers(reference_results)
|
322 |
-
display_reference = "\n\n" + "".join(display_reference)
|
323 |
-
inputs = (
|
324 |
-
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
|
325 |
-
.replace("{query}", inputs)
|
326 |
-
.replace("{web_results}", "\n\n".join(reference_results))
|
327 |
-
.replace("{reply_language}", reply_language )
|
328 |
-
)
|
329 |
-
else:
|
330 |
-
display_reference = ""
|
331 |
-
|
332 |
-
if len(openai_api_key) == 0 and not shared.state.multi_api_key:
|
333 |
-
status_text = standard_error_msg + no_apikey_msg
|
334 |
-
logging.info(status_text)
|
335 |
-
chatbot.append((inputs, ""))
|
336 |
-
if len(history) == 0:
|
337 |
-
history.append(construct_user(inputs))
|
338 |
-
history.append("")
|
339 |
-
all_token_counts.append(0)
|
340 |
-
else:
|
341 |
-
history[-2] = construct_user(inputs)
|
342 |
-
yield chatbot+[(inputs, "")], history, status_text, all_token_counts
|
343 |
-
return
|
344 |
-
elif len(inputs.strip()) == 0:
|
345 |
-
status_text = standard_error_msg + no_input_msg
|
346 |
-
logging.info(status_text)
|
347 |
-
yield chatbot+[(inputs, "")], history, status_text, all_token_counts
|
348 |
-
return
|
349 |
-
|
350 |
-
if stream:
|
351 |
-
logging.info("使用流式传输")
|
352 |
-
iter = stream_predict(
|
353 |
-
openai_api_key,
|
354 |
-
system_prompt,
|
355 |
-
history,
|
356 |
-
inputs,
|
357 |
-
chatbot,
|
358 |
-
all_token_counts,
|
359 |
-
top_p,
|
360 |
-
temperature,
|
361 |
-
selected_model,
|
362 |
-
fake_input=old_inputs,
|
363 |
-
display_append=display_reference
|
364 |
-
)
|
365 |
-
for chatbot, history, status_text, all_token_counts in iter:
|
366 |
-
if shared.state.interrupted:
|
367 |
-
shared.state.recover()
|
368 |
-
return
|
369 |
-
yield chatbot, history, status_text, all_token_counts
|
370 |
-
else:
|
371 |
-
logging.info("不使用流式传输")
|
372 |
-
chatbot, history, status_text, all_token_counts = predict_all(
|
373 |
-
openai_api_key,
|
374 |
-
system_prompt,
|
375 |
-
history,
|
376 |
-
inputs,
|
377 |
-
chatbot,
|
378 |
-
all_token_counts,
|
379 |
-
top_p,
|
380 |
-
temperature,
|
381 |
-
selected_model,
|
382 |
-
fake_input=old_inputs,
|
383 |
-
display_append=display_reference
|
384 |
-
)
|
385 |
-
yield chatbot, history, status_text, all_token_counts
|
386 |
-
|
387 |
-
logging.info(f"传输完毕。当前token计数为{all_token_counts}")
|
388 |
-
if len(history) > 1 and history[-1]["content"] != inputs:
|
389 |
-
logging.info(
|
390 |
-
"回答为:"
|
391 |
-
+ colorama.Fore.BLUE
|
392 |
-
+ f"{history[-1]['content']}"
|
393 |
-
+ colorama.Style.RESET_ALL
|
394 |
-
)
|
395 |
-
|
396 |
-
if limited_context:
|
397 |
-
history = history[-4:]
|
398 |
-
all_token_counts = all_token_counts[-2:]
|
399 |
-
yield chatbot, history, status_text, all_token_counts
|
400 |
-
|
401 |
-
if stream:
|
402 |
-
max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
|
403 |
-
else:
|
404 |
-
max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
|
405 |
-
|
406 |
-
if sum(all_token_counts) > max_token and should_check_token_count:
|
407 |
-
print(all_token_counts)
|
408 |
-
count = 0
|
409 |
-
while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
|
410 |
-
count += 1
|
411 |
-
del all_token_counts[0]
|
412 |
-
del history[:2]
|
413 |
-
logging.info(status_text)
|
414 |
-
status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
|
415 |
-
yield chatbot, history, status_text, all_token_counts
|
416 |
-
|
417 |
-
|
418 |
-
def retry(
|
419 |
-
openai_api_key,
|
420 |
-
system_prompt,
|
421 |
-
history,
|
422 |
-
chatbot,
|
423 |
-
token_count,
|
424 |
-
top_p,
|
425 |
-
temperature,
|
426 |
-
stream=False,
|
427 |
-
selected_model=MODELS[0],
|
428 |
-
reply_language="中文",
|
429 |
-
):
|
430 |
-
logging.info("重试中……")
|
431 |
-
if len(history) == 0:
|
432 |
-
yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
|
433 |
-
return
|
434 |
-
history.pop()
|
435 |
-
inputs = history.pop()["content"]
|
436 |
-
token_count.pop()
|
437 |
-
iter = predict(
|
438 |
-
openai_api_key,
|
439 |
-
system_prompt,
|
440 |
-
history,
|
441 |
-
inputs,
|
442 |
-
chatbot,
|
443 |
-
token_count,
|
444 |
-
top_p,
|
445 |
-
temperature,
|
446 |
-
stream=stream,
|
447 |
-
selected_model=selected_model,
|
448 |
-
reply_language=reply_language,
|
449 |
-
)
|
450 |
-
logging.info("重试中……")
|
451 |
-
for x in iter:
|
452 |
-
yield x
|
453 |
-
logging.info("重试完毕")
|
454 |
-
|
455 |
-
|
456 |
-
def reduce_token_size(
|
457 |
-
openai_api_key,
|
458 |
-
system_prompt,
|
459 |
-
history,
|
460 |
-
chatbot,
|
461 |
-
token_count,
|
462 |
-
top_p,
|
463 |
-
temperature,
|
464 |
-
max_token_count,
|
465 |
-
selected_model=MODELS[0],
|
466 |
-
reply_language="中文",
|
467 |
-
):
|
468 |
-
logging.info("开始减少token数量……")
|
469 |
-
iter = predict(
|
470 |
-
openai_api_key,
|
471 |
-
system_prompt,
|
472 |
-
history,
|
473 |
-
summarize_prompt,
|
474 |
-
chatbot,
|
475 |
-
token_count,
|
476 |
-
top_p,
|
477 |
-
temperature,
|
478 |
-
selected_model=selected_model,
|
479 |
-
should_check_token_count=False,
|
480 |
-
reply_language=reply_language,
|
481 |
-
)
|
482 |
-
logging.info(f"chatbot: {chatbot}")
|
483 |
-
flag = False
|
484 |
-
for chatbot, history, status_text, previous_token_count in iter:
|
485 |
-
num_chat = find_n(previous_token_count, max_token_count)
|
486 |
-
logging.info(f"previous_token_count: {previous_token_count}, keeping {num_chat} chats")
|
487 |
-
if flag:
|
488 |
-
chatbot = chatbot[:-1]
|
489 |
-
flag = True
|
490 |
-
history = history[-2*num_chat:] if num_chat > 0 else []
|
491 |
-
token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
|
492 |
-
msg = f"保留了最近{num_chat}轮对话"
|
493 |
-
yield chatbot, history, msg + "," + construct_token_message(
|
494 |
-
token_count if len(token_count) > 0 else [0],
|
495 |
-
), token_count
|
496 |
-
logging.info(msg)
|
497 |
-
logging.info("减少token数量完毕")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/config.py
CHANGED
@@ -3,7 +3,7 @@ from contextlib import contextmanager
|
|
3 |
import os
|
4 |
import logging
|
5 |
import sys
|
6 |
-
import json
|
7 |
|
8 |
from . import shared
|
9 |
|
|
|
3 |
import os
|
4 |
import logging
|
5 |
import sys
|
6 |
+
import commentjson as json
|
7 |
|
8 |
from . import shared
|
9 |
|
modules/llama_func.py
CHANGED
@@ -44,40 +44,44 @@ def get_documents(file_src):
|
|
44 |
filename = os.path.basename(filepath)
|
45 |
file_type = os.path.splitext(filepath)[1]
|
46 |
logging.info(f"loading file: {filename}")
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
text = add_space(text_raw)
|
82 |
# text = block_split(text)
|
83 |
# documents += text
|
|
|
44 |
filename = os.path.basename(filepath)
|
45 |
file_type = os.path.splitext(filepath)[1]
|
46 |
logging.info(f"loading file: {filename}")
|
47 |
+
try:
|
48 |
+
if file_type == ".pdf":
|
49 |
+
logging.debug("Loading PDF...")
|
50 |
+
try:
|
51 |
+
from modules.pdf_func import parse_pdf
|
52 |
+
from modules.config import advance_docs
|
53 |
+
two_column = advance_docs["pdf"].get("two_column", False)
|
54 |
+
pdftext = parse_pdf(filepath, two_column).text
|
55 |
+
except:
|
56 |
+
pdftext = ""
|
57 |
+
with open(filepath, 'rb') as pdfFileObj:
|
58 |
+
pdfReader = PyPDF2.PdfReader(pdfFileObj)
|
59 |
+
for page in tqdm(pdfReader.pages):
|
60 |
+
pdftext += page.extract_text()
|
61 |
+
text_raw = pdftext
|
62 |
+
elif file_type == ".docx":
|
63 |
+
logging.debug("Loading Word...")
|
64 |
+
DocxReader = download_loader("DocxReader")
|
65 |
+
loader = DocxReader()
|
66 |
+
text_raw = loader.load_data(file=filepath)[0].text
|
67 |
+
elif file_type == ".epub":
|
68 |
+
logging.debug("Loading EPUB...")
|
69 |
+
EpubReader = download_loader("EpubReader")
|
70 |
+
loader = EpubReader()
|
71 |
+
text_raw = loader.load_data(file=filepath)[0].text
|
72 |
+
elif file_type == ".xlsx":
|
73 |
+
logging.debug("Loading Excel...")
|
74 |
+
text_list = excel_to_string(filepath)
|
75 |
+
for elem in text_list:
|
76 |
+
documents.append(Document(elem))
|
77 |
+
continue
|
78 |
+
else:
|
79 |
+
logging.debug("Loading text file...")
|
80 |
+
with open(filepath, "r", encoding="utf-8") as f:
|
81 |
+
text_raw = f.read()
|
82 |
+
except Exception as e:
|
83 |
+
logging.error(f"Error loading file: {filename}")
|
84 |
+
pass
|
85 |
text = add_space(text_raw)
|
86 |
# text = block_split(text)
|
87 |
# documents += text
|
modules/models.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from typing import TYPE_CHECKING, List
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import json
|
6 |
+
import commentjson as cjson
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import requests
|
10 |
+
import urllib3
|
11 |
+
|
12 |
+
from tqdm import tqdm
|
13 |
+
import colorama
|
14 |
+
from duckduckgo_search import ddg
|
15 |
+
import asyncio
|
16 |
+
import aiohttp
|
17 |
+
from enum import Enum
|
18 |
+
|
19 |
+
from .presets import *
|
20 |
+
from .llama_func import *
|
21 |
+
from .utils import *
|
22 |
+
from . import shared
|
23 |
+
from .config import retrieve_proxy
|
24 |
+
from .base_model import BaseLLMModel, ModelType
|
25 |
+
|
26 |
+
|
27 |
+
class OpenAIClient(BaseLLMModel):
|
28 |
+
def __init__(
|
29 |
+
self, model_name, api_key, system_prompt=INITIAL_SYSTEM_PROMPT, temperature=1.0, top_p=1.0
|
30 |
+
) -> None:
|
31 |
+
super().__init__(model_name=model_name, temperature=temperature, top_p=top_p, system_prompt=system_prompt)
|
32 |
+
self.api_key = api_key
|
33 |
+
self.completion_url = shared.state.completion_url
|
34 |
+
self.usage_api_url = shared.state.usage_api_url
|
35 |
+
self.headers = {
|
36 |
+
"Content-Type": "application/json",
|
37 |
+
"Authorization": f"Bearer {self.api_key}",
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
def get_answer_stream_iter(self):
|
42 |
+
response = self._get_response(stream=True)
|
43 |
+
if response is not None:
|
44 |
+
iter = self._decode_chat_response(response)
|
45 |
+
partial_text = ""
|
46 |
+
for i in iter:
|
47 |
+
partial_text += i
|
48 |
+
yield partial_text
|
49 |
+
else:
|
50 |
+
yield standard_error_msg + general_error_msg
|
51 |
+
|
52 |
+
def get_answer_at_once(self):
|
53 |
+
response = self._get_response()
|
54 |
+
response = json.loads(response.text)
|
55 |
+
content = response["choices"][0]["message"]["content"]
|
56 |
+
total_token_count = response["usage"]["total_tokens"]
|
57 |
+
return content, total_token_count
|
58 |
+
|
59 |
+
def count_token(self, user_input):
|
60 |
+
input_token_count = count_token(construct_user(user_input))
|
61 |
+
if self.system_prompt is not None and len(self.all_token_counts) == 0:
|
62 |
+
system_prompt_token_count = count_token(construct_system(self.system_prompt))
|
63 |
+
return input_token_count + system_prompt_token_count
|
64 |
+
return input_token_count
|
65 |
+
|
66 |
+
def set_system_prompt(self, new_system_prompt):
|
67 |
+
self.system_prompt = new_system_prompt
|
68 |
+
|
69 |
+
def billing_info(self):
|
70 |
+
try:
|
71 |
+
curr_time = datetime.datetime.now()
|
72 |
+
last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
|
73 |
+
first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
|
74 |
+
usage_url = f"{self.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
|
75 |
+
try:
|
76 |
+
usage_data = self._get_billing_data(usage_url)
|
77 |
+
except Exception as e:
|
78 |
+
logging.error(f"获取API使用情况失败:"+str(e))
|
79 |
+
return f"**获取API使用情况失败**"
|
80 |
+
rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
|
81 |
+
return f"**本月使用金额** \u3000 ${rounded_usage}"
|
82 |
+
except requests.exceptions.ConnectTimeout:
|
83 |
+
status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
84 |
+
return status_text
|
85 |
+
except requests.exceptions.ReadTimeout:
|
86 |
+
status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
|
87 |
+
return status_text
|
88 |
+
except Exception as e:
|
89 |
+
logging.error(f"获取API使用情况失败:"+str(e))
|
90 |
+
return standard_error_msg + error_retrieve_prompt
|
91 |
+
|
92 |
+
@shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
|
93 |
+
def _get_response(self, stream=False):
|
94 |
+
openai_api_key = self.api_key
|
95 |
+
system_prompt = self.system_prompt
|
96 |
+
history = self.history
|
97 |
+
logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
|
98 |
+
temperature = self.temperature
|
99 |
+
top_p = self.top_p
|
100 |
+
selected_model = self.model_name
|
101 |
+
headers = {
|
102 |
+
"Content-Type": "application/json",
|
103 |
+
"Authorization": f"Bearer {openai_api_key}",
|
104 |
+
}
|
105 |
+
|
106 |
+
if system_prompt is not None:
|
107 |
+
history = [construct_system(system_prompt), *history]
|
108 |
+
|
109 |
+
payload = {
|
110 |
+
"model": selected_model,
|
111 |
+
"messages": history, # [{"role": "user", "content": f"{inputs}"}],
|
112 |
+
"temperature": temperature, # 1.0,
|
113 |
+
"top_p": top_p, # 1.0,
|
114 |
+
"n": 1,
|
115 |
+
"stream": stream,
|
116 |
+
"presence_penalty": 0,
|
117 |
+
"frequency_penalty": 0,
|
118 |
+
}
|
119 |
+
if stream:
|
120 |
+
timeout = timeout_streaming
|
121 |
+
else:
|
122 |
+
timeout = TIMEOUT_ALL
|
123 |
+
|
124 |
+
# 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
|
125 |
+
if shared.state.completion_url != COMPLETION_URL:
|
126 |
+
logging.info(f"使用自定义API URL: {shared.state.completion_url}")
|
127 |
+
|
128 |
+
with retrieve_proxy():
|
129 |
+
try:
|
130 |
+
response = requests.post(
|
131 |
+
shared.state.completion_url,
|
132 |
+
headers=headers,
|
133 |
+
json=payload,
|
134 |
+
stream=stream,
|
135 |
+
timeout=timeout,
|
136 |
+
)
|
137 |
+
except:
|
138 |
+
return None
|
139 |
+
return response
|
140 |
+
|
141 |
+
def _get_billing_data(self, usage_url):
|
142 |
+
with retrieve_proxy():
|
143 |
+
response = requests.get(
|
144 |
+
usage_url,
|
145 |
+
headers=self.headers,
|
146 |
+
timeout=TIMEOUT_ALL,
|
147 |
+
)
|
148 |
+
|
149 |
+
if response.status_code == 200:
|
150 |
+
data = response.json()
|
151 |
+
return data
|
152 |
+
else:
|
153 |
+
raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
|
154 |
+
|
155 |
+
def _decode_chat_response(self, response):
|
156 |
+
for chunk in response.iter_lines():
|
157 |
+
if chunk:
|
158 |
+
chunk = chunk.decode()
|
159 |
+
chunk_length = len(chunk)
|
160 |
+
try:
|
161 |
+
chunk = json.loads(chunk[6:])
|
162 |
+
except json.JSONDecodeError:
|
163 |
+
print(f"JSON解析错误,收到的内容: {chunk}")
|
164 |
+
continue
|
165 |
+
if chunk_length > 6 and "delta" in chunk["choices"][0]:
|
166 |
+
if chunk["choices"][0]["finish_reason"] == "stop":
|
167 |
+
break
|
168 |
+
try:
|
169 |
+
yield chunk["choices"][0]["delta"]["content"]
|
170 |
+
except Exception as e:
|
171 |
+
# logging.error(f"Error: {e}")
|
172 |
+
continue
|
173 |
+
|
174 |
+
def get_model(model_name, access_key=None, temprature=None, top_p=None, system_prompt = None) -> BaseLLMModel:
|
175 |
+
model_type = ModelType.get_type(model_name)
|
176 |
+
if model_type == ModelType.OpenAI:
|
177 |
+
model = OpenAIClient(model_name, access_key, system_prompt, temprature, top_p)
|
178 |
+
return model
|
179 |
+
|
180 |
+
if __name__=="__main__":
|
181 |
+
with open("config.json", "r") as f:
|
182 |
+
openai_api_key = cjson.load(f)["openai_api_key"]
|
183 |
+
client = OpenAIClient("gpt-3.5-turbo", openai_api_key)
|
184 |
+
chatbot = []
|
185 |
+
stream = False
|
186 |
+
# 测试账单功能
|
187 |
+
print(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
|
188 |
+
print(client.billing_info())
|
189 |
+
# 测试问答
|
190 |
+
print(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET)
|
191 |
+
question = "巴黎是中国的首都吗?"
|
192 |
+
for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
|
193 |
+
print(i)
|
194 |
+
print(f"测试问答后history : {client.history}")
|
195 |
+
# 测试记忆力
|
196 |
+
print(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET)
|
197 |
+
question = "我刚刚问了你什么问题?"
|
198 |
+
for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
|
199 |
+
print(i)
|
200 |
+
print(f"测试记忆力后history : {client.history}")
|
201 |
+
# 测试重试功能
|
202 |
+
print(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
|
203 |
+
for i in client.retry(chatbot=chatbot, stream=stream):
|
204 |
+
print(i)
|
205 |
+
print(f"重试后history : {client.history}")
|
206 |
+
# # 测试总结功能
|
207 |
+
# print(colorama.Back.GREEN + "测试总结功能" + colorama.Back.RESET)
|
208 |
+
# chatbot, msg = client.reduce_token_size(chatbot=chatbot)
|
209 |
+
# print(chatbot, msg)
|
210 |
+
# print(f"总结后history: {client.history}")
|
modules/openai_func.py
DELETED
@@ -1,65 +0,0 @@
|
|
1 |
-
import requests
|
2 |
-
import logging
|
3 |
-
from modules.presets import (
|
4 |
-
timeout_all,
|
5 |
-
USAGE_API_URL,
|
6 |
-
BALANCE_API_URL,
|
7 |
-
standard_error_msg,
|
8 |
-
connection_timeout_prompt,
|
9 |
-
error_retrieve_prompt,
|
10 |
-
read_timeout_prompt
|
11 |
-
)
|
12 |
-
|
13 |
-
from . import shared
|
14 |
-
from modules.config import retrieve_proxy
|
15 |
-
import os, datetime
|
16 |
-
|
17 |
-
def get_billing_data(openai_api_key, billing_url):
|
18 |
-
headers = {
|
19 |
-
"Content-Type": "application/json",
|
20 |
-
"Authorization": f"Bearer {openai_api_key}"
|
21 |
-
}
|
22 |
-
|
23 |
-
timeout = timeout_all
|
24 |
-
with retrieve_proxy():
|
25 |
-
response = requests.get(
|
26 |
-
billing_url,
|
27 |
-
headers=headers,
|
28 |
-
timeout=timeout,
|
29 |
-
)
|
30 |
-
|
31 |
-
if response.status_code == 200:
|
32 |
-
data = response.json()
|
33 |
-
return data
|
34 |
-
else:
|
35 |
-
raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
|
36 |
-
|
37 |
-
|
38 |
-
def get_usage(openai_api_key):
|
39 |
-
try:
|
40 |
-
curr_time = datetime.datetime.now()
|
41 |
-
last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
|
42 |
-
first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
|
43 |
-
usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
|
44 |
-
try:
|
45 |
-
usage_data = get_billing_data(openai_api_key, usage_url)
|
46 |
-
except Exception as e:
|
47 |
-
logging.error(f"获取API使用情况失败:"+str(e))
|
48 |
-
return f"**获取API使用情况失败**"
|
49 |
-
rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
|
50 |
-
return f"**本月使用金额** \u3000 ${rounded_usage}"
|
51 |
-
except requests.exceptions.ConnectTimeout:
|
52 |
-
status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
53 |
-
return status_text
|
54 |
-
except requests.exceptions.ReadTimeout:
|
55 |
-
status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
|
56 |
-
return status_text
|
57 |
-
except Exception as e:
|
58 |
-
logging.error(f"获取API使用情况失败:"+str(e))
|
59 |
-
return standard_error_msg + error_retrieve_prompt
|
60 |
-
|
61 |
-
def get_last_day_of_month(any_day):
|
62 |
-
# The day 28 exists in every month. 4 days later, it's always next month
|
63 |
-
next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
|
64 |
-
# subtracting the number of the current day brings us back one month
|
65 |
-
return next_month - datetime.timedelta(days=next_month.day)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/presets.py
CHANGED
@@ -3,26 +3,29 @@ import gradio as gr
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
# ChatGPT 设置
|
6 |
-
|
7 |
API_HOST = "api.openai.com"
|
8 |
COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
|
9 |
BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
|
10 |
USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
|
11 |
HISTORY_DIR = Path("history")
|
|
|
12 |
TEMPLATES_DIR = "templates"
|
13 |
|
14 |
# 错误信息
|
15 |
standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
|
16 |
-
|
|
|
17 |
connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
|
18 |
read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
|
19 |
proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
|
20 |
ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
|
21 |
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
|
22 |
no_input_msg = "请输入对话内容。" # 未输入对话内容
|
|
|
23 |
|
24 |
timeout_streaming = 10 # 流式对话时的超时时间
|
25 |
-
|
26 |
enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
|
27 |
HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
|
28 |
CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
|
@@ -57,33 +60,18 @@ MODELS = [
|
|
57 |
"gpt-4-32k-0314",
|
58 |
] # 可选的模型
|
59 |
|
60 |
-
|
61 |
-
"gpt-3.5-turbo":
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
"gpt-
|
66 |
-
|
67 |
-
"all": 3500
|
68 |
-
},
|
69 |
-
"gpt-4": {
|
70 |
-
"streaming": 7500,
|
71 |
-
"all": 7500
|
72 |
-
},
|
73 |
-
"gpt-4-0314": {
|
74 |
-
"streaming": 7500,
|
75 |
-
"all": 7500
|
76 |
-
},
|
77 |
-
"gpt-4-32k": {
|
78 |
-
"streaming": 31000,
|
79 |
-
"all": 31000
|
80 |
-
},
|
81 |
-
"gpt-4-32k-0314": {
|
82 |
-
"streaming": 31000,
|
83 |
-
"all": 31000
|
84 |
-
}
|
85 |
}
|
86 |
|
|
|
|
|
|
|
87 |
REPLY_LANGUAGES = [
|
88 |
"简体中文",
|
89 |
"繁體中文",
|
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
# ChatGPT 设置
|
6 |
+
INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
|
7 |
API_HOST = "api.openai.com"
|
8 |
COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
|
9 |
BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
|
10 |
USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
|
11 |
HISTORY_DIR = Path("history")
|
12 |
+
HISTORY_DIR = "history"
|
13 |
TEMPLATES_DIR = "templates"
|
14 |
|
15 |
# 错误信息
|
16 |
standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
|
17 |
+
general_error_msg = "获取对话时发生错误,请查看后台日志"
|
18 |
+
error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。"
|
19 |
connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
|
20 |
read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
|
21 |
proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
|
22 |
ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
|
23 |
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
|
24 |
no_input_msg = "请输入对话内容。" # 未输入对话内容
|
25 |
+
billing_not_applicable_msg = "模型本地运行中" # 本地运行的模型返回的账单信息
|
26 |
|
27 |
timeout_streaming = 10 # 流式对话时的超时时间
|
28 |
+
TIMEOUT_ALL = 200 # 非流式对话时的超时时间
|
29 |
enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
|
30 |
HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
|
31 |
CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
|
|
|
60 |
"gpt-4-32k-0314",
|
61 |
] # 可选的模型
|
62 |
|
63 |
+
MODEL_TOKEN_LIMIT = {
|
64 |
+
"gpt-3.5-turbo": 4096,
|
65 |
+
"gpt-3.5-turbo-0301": 4096,
|
66 |
+
"gpt-4": 8192,
|
67 |
+
"gpt-4-0314": 8192,
|
68 |
+
"gpt-4-32k": 32768,
|
69 |
+
"gpt-4-32k-0314": 32768
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
}
|
71 |
|
72 |
+
TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
|
73 |
+
REDUCE_TOKEN_FACTOR = 0.5 # 与模型token上限想乘,得到目标token数。减少token占用时,将token占用减少到目标token数以下。
|
74 |
+
|
75 |
REPLY_LANGUAGES = [
|
76 |
"简体中文",
|
77 |
"繁體中文",
|
modules/utils.py
CHANGED
@@ -153,47 +153,6 @@ def construct_assistant(text):
|
|
153 |
return construct_text("assistant", text)
|
154 |
|
155 |
|
156 |
-
def construct_token_message(tokens: List[int]):
|
157 |
-
token_sum = 0
|
158 |
-
for i in range(len(tokens)):
|
159 |
-
token_sum += sum(tokens[: i + 1])
|
160 |
-
return f"Token 计数: {sum(tokens)},本次对话累计消耗了 {token_sum} tokens"
|
161 |
-
|
162 |
-
|
163 |
-
def delete_first_conversation(history, previous_token_count):
|
164 |
-
if history:
|
165 |
-
del history[:2]
|
166 |
-
del previous_token_count[0]
|
167 |
-
return (
|
168 |
-
history,
|
169 |
-
previous_token_count,
|
170 |
-
construct_token_message(previous_token_count),
|
171 |
-
)
|
172 |
-
|
173 |
-
|
174 |
-
def delete_last_conversation(chatbot, history, previous_token_count):
|
175 |
-
if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
|
176 |
-
logging.info("由于包含报错信息,只删除chatbot记录")
|
177 |
-
chatbot.pop()
|
178 |
-
return chatbot, history
|
179 |
-
if len(history) > 0:
|
180 |
-
logging.info("删除了一组对话历史")
|
181 |
-
history.pop()
|
182 |
-
history.pop()
|
183 |
-
if len(chatbot) > 0:
|
184 |
-
logging.info("删除了一组chatbot对话")
|
185 |
-
chatbot.pop()
|
186 |
-
if len(previous_token_count) > 0:
|
187 |
-
logging.info("删除了一组对话的token计数记录")
|
188 |
-
previous_token_count.pop()
|
189 |
-
return (
|
190 |
-
chatbot,
|
191 |
-
history,
|
192 |
-
previous_token_count,
|
193 |
-
construct_token_message(previous_token_count),
|
194 |
-
)
|
195 |
-
|
196 |
-
|
197 |
def save_file(filename, system, history, chatbot, user_name):
|
198 |
logging.info(f"{user_name} 保存对话历史中……")
|
199 |
os.makedirs(HISTORY_DIR / user_name, exist_ok=True)
|
@@ -212,56 +171,12 @@ def save_file(filename, system, history, chatbot, user_name):
|
|
212 |
return os.path.join(HISTORY_DIR / user_name, filename)
|
213 |
|
214 |
|
215 |
-
def save_chat_history(filename, system, history, chatbot, user_name):
|
216 |
-
if filename == "":
|
217 |
-
return
|
218 |
-
if not filename.endswith(".json"):
|
219 |
-
filename += ".json"
|
220 |
-
return save_file(filename, system, history, chatbot, user_name)
|
221 |
-
|
222 |
-
|
223 |
-
def export_markdown(filename, system, history, chatbot, user_name):
|
224 |
-
if filename == "":
|
225 |
-
return
|
226 |
-
if not filename.endswith(".md"):
|
227 |
-
filename += ".md"
|
228 |
-
return save_file(filename, system, history, chatbot, user_name)
|
229 |
-
|
230 |
-
|
231 |
-
def load_chat_history(filename, system, history, chatbot, user_name):
|
232 |
-
logging.info(f"{user_name} 加载对话历史中……")
|
233 |
-
if type(filename) != str:
|
234 |
-
filename = filename.name
|
235 |
-
try:
|
236 |
-
with open(os.path.join(HISTORY_DIR / user_name, filename), "r") as f:
|
237 |
-
json_s = json.load(f)
|
238 |
-
try:
|
239 |
-
if type(json_s["history"][0]) == str:
|
240 |
-
logging.info("历史记录格式为旧版,正在转换……")
|
241 |
-
new_history = []
|
242 |
-
for index, item in enumerate(json_s["history"]):
|
243 |
-
if index % 2 == 0:
|
244 |
-
new_history.append(construct_user(item))
|
245 |
-
else:
|
246 |
-
new_history.append(construct_assistant(item))
|
247 |
-
json_s["history"] = new_history
|
248 |
-
logging.info(new_history)
|
249 |
-
except:
|
250 |
-
# 没有对话历史
|
251 |
-
pass
|
252 |
-
logging.info(f"{user_name} 加载对话历史完毕")
|
253 |
-
return filename, json_s["system"], json_s["history"], json_s["chatbot"]
|
254 |
-
except FileNotFoundError:
|
255 |
-
logging.info(f"{user_name} 没有找到对话历史文件,不执行任何操作")
|
256 |
-
return filename, system, history, chatbot
|
257 |
-
|
258 |
-
|
259 |
def sorted_by_pinyin(list):
|
260 |
return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
|
261 |
|
262 |
|
263 |
def get_file_names(dir, plain=False, filetypes=[".json"]):
|
264 |
-
logging.
|
265 |
files = []
|
266 |
try:
|
267 |
for type in filetypes:
|
@@ -279,14 +194,13 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
|
|
279 |
|
280 |
|
281 |
def get_history_names(plain=False, user_name=""):
|
282 |
-
logging.
|
283 |
-
return get_file_names(HISTORY_DIR
|
284 |
|
285 |
|
286 |
def load_template(filename, mode=0):
|
287 |
-
logging.
|
288 |
lines = []
|
289 |
-
logging.info("Loading template...")
|
290 |
if filename.endswith(".json"):
|
291 |
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
292 |
lines = json.load(f)
|
@@ -310,7 +224,7 @@ def load_template(filename, mode=0):
|
|
310 |
|
311 |
|
312 |
def get_template_names(plain=False):
|
313 |
-
logging.
|
314 |
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
315 |
|
316 |
|
@@ -322,11 +236,6 @@ def get_template_content(templates, selection, original_system_prompt):
|
|
322 |
return original_system_prompt
|
323 |
|
324 |
|
325 |
-
def reset_state():
|
326 |
-
logging.info("重置状态")
|
327 |
-
return [], [], [], construct_token_message([0])
|
328 |
-
|
329 |
-
|
330 |
def reset_textbox():
|
331 |
logging.debug("重置文本框")
|
332 |
return gr.update(value="")
|
@@ -530,3 +439,9 @@ def excel_to_string(file_path):
|
|
530 |
|
531 |
|
532 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
return construct_text("assistant", text)
|
154 |
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
def save_file(filename, system, history, chatbot, user_name):
|
157 |
logging.info(f"{user_name} 保存对话历史中……")
|
158 |
os.makedirs(HISTORY_DIR / user_name, exist_ok=True)
|
|
|
171 |
return os.path.join(HISTORY_DIR / user_name, filename)
|
172 |
|
173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
def sorted_by_pinyin(list):
|
175 |
return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
|
176 |
|
177 |
|
178 |
def get_file_names(dir, plain=False, filetypes=[".json"]):
|
179 |
+
logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
|
180 |
files = []
|
181 |
try:
|
182 |
for type in filetypes:
|
|
|
194 |
|
195 |
|
196 |
def get_history_names(plain=False, user_name=""):
|
197 |
+
logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
|
198 |
+
return get_file_names(os.path.join(HISTORY_DIR, user_name), plain)
|
199 |
|
200 |
|
201 |
def load_template(filename, mode=0):
|
202 |
+
logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
|
203 |
lines = []
|
|
|
204 |
if filename.endswith(".json"):
|
205 |
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
206 |
lines = json.load(f)
|
|
|
224 |
|
225 |
|
226 |
def get_template_names(plain=False):
|
227 |
+
logging.debug("获取模���文件名列表")
|
228 |
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
229 |
|
230 |
|
|
|
236 |
return original_system_prompt
|
237 |
|
238 |
|
|
|
|
|
|
|
|
|
|
|
239 |
def reset_textbox():
|
240 |
logging.debug("重置文本框")
|
241 |
return gr.update(value="")
|
|
|
439 |
|
440 |
|
441 |
return result
|
442 |
+
|
443 |
+
def get_last_day_of_month(any_day):
|
444 |
+
# The day 28 exists in every month. 4 days later, it's always next month
|
445 |
+
next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
|
446 |
+
# subtracting the number of the current day brings us back one month
|
447 |
+
return next_month - datetime.timedelta(days=next_month.day)
|