Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
f2c2a56
1
Parent(s):
c857ac1
让新增的参数们真正有用
Browse files- ChuanhuChatbot.py +10 -2
- modules/base_model.py +33 -7
- modules/models.py +17 -11
ChuanhuChatbot.py
CHANGED
@@ -221,7 +221,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
221 |
value="",
|
222 |
lines=1,
|
223 |
)
|
224 |
-
|
225 |
show_label=True,
|
226 |
placeholder=f"用于定位滥用行为",
|
227 |
label="用户名",
|
@@ -379,8 +379,16 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
379 |
downloadFile.change(**load_history_from_file_args)
|
380 |
|
381 |
# Advanced
|
382 |
-
top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
|
383 |
temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
default_btn.click(
|
385 |
reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
|
386 |
)
|
|
|
221 |
value="",
|
222 |
lines=1,
|
223 |
)
|
224 |
+
user_identifier_txt = gr.Textbox(
|
225 |
show_label=True,
|
226 |
placeholder=f"用于定位滥用行为",
|
227 |
label="用户名",
|
|
|
379 |
downloadFile.change(**load_history_from_file_args)
|
380 |
|
381 |
# Advanced
|
|
|
382 |
temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
|
383 |
+
top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
|
384 |
+
n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
|
385 |
+
stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
|
386 |
+
max_tokens_slider.change(current_model.value.set_max_tokens, [max_tokens_slider], None)
|
387 |
+
presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
|
388 |
+
frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
|
389 |
+
logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
|
390 |
+
user_identifier_txt.change(current_model.value.set_user_identifier, [user_identifier_txt], None)
|
391 |
+
|
392 |
default_btn.click(
|
393 |
reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
|
394 |
)
|
modules/base_model.py
CHANGED
@@ -67,16 +67,12 @@ class BaseLLMModel:
|
|
67 |
self.temperature = temperature
|
68 |
self.top_p = top_p
|
69 |
self.n_choices = n_choices
|
70 |
-
self.
|
71 |
-
self.max_generation_token =
|
72 |
-
max_generation_token
|
73 |
-
if max_generation_token is not None
|
74 |
-
else self.token_upper_limit
|
75 |
-
)
|
76 |
self.presence_penalty = presence_penalty
|
77 |
self.frequency_penalty = frequency_penalty
|
78 |
self.logit_bias = logit_bias
|
79 |
-
self.
|
80 |
|
81 |
def get_answer_stream_iter(self):
|
82 |
"""stream predict, need to be implemented
|
@@ -367,6 +363,36 @@ class BaseLLMModel:
|
|
367 |
def set_top_p(self, new_top_p):
|
368 |
self.top_p = new_top_p
|
369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
def set_system_prompt(self, new_system_prompt):
|
371 |
self.system_prompt = new_system_prompt
|
372 |
|
|
|
67 |
self.temperature = temperature
|
68 |
self.top_p = top_p
|
69 |
self.n_choices = n_choices
|
70 |
+
self.stop_sequence = stop
|
71 |
+
self.max_generation_token = None
|
|
|
|
|
|
|
|
|
72 |
self.presence_penalty = presence_penalty
|
73 |
self.frequency_penalty = frequency_penalty
|
74 |
self.logit_bias = logit_bias
|
75 |
+
self.user_identifier = user
|
76 |
|
77 |
def get_answer_stream_iter(self):
|
78 |
"""stream predict, need to be implemented
|
|
|
363 |
def set_top_p(self, new_top_p):
|
364 |
self.top_p = new_top_p
|
365 |
|
366 |
+
def set_n_choices(self, new_n_choices):
|
367 |
+
self.n_choices = new_n_choices
|
368 |
+
|
369 |
+
def set_stop_sequence(self, new_stop_sequence: str):
|
370 |
+
new_stop_sequence = new_stop_sequence.split(",")
|
371 |
+
self.stop_sequence = new_stop_sequence
|
372 |
+
|
373 |
+
def set_max_tokens(self, new_max_tokens):
|
374 |
+
self.max_generation_token = new_max_tokens
|
375 |
+
|
376 |
+
def set_presence_penalty(self, new_presence_penalty):
|
377 |
+
self.presence_penalty = new_presence_penalty
|
378 |
+
|
379 |
+
def set_frequency_penalty(self, new_frequency_penalty):
|
380 |
+
self.frequency_penalty = new_frequency_penalty
|
381 |
+
|
382 |
+
def set_logit_bias(self, logit_bias):
|
383 |
+
logit_bias = logit_bias.split()
|
384 |
+
bias_map = {}
|
385 |
+
encoding = tiktoken.get_encoding("cl100k_base")
|
386 |
+
for line in logit_bias:
|
387 |
+
word, bias_amount = line.split(":")
|
388 |
+
if word:
|
389 |
+
for token in encoding.encode(word):
|
390 |
+
bias_map[token] = float(bias_amount)
|
391 |
+
self.logit_bias = bias_map
|
392 |
+
|
393 |
+
def set_user_identifier(self, new_user_identifier):
|
394 |
+
self.user_identifier = new_user_identifier
|
395 |
+
|
396 |
def set_system_prompt(self, new_system_prompt):
|
397 |
self.system_prompt = new_system_prompt
|
398 |
|
modules/models.py
CHANGED
@@ -103,9 +103,6 @@ class OpenAIClient(BaseLLMModel):
|
|
103 |
system_prompt = self.system_prompt
|
104 |
history = self.history
|
105 |
logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
|
106 |
-
temperature = self.temperature
|
107 |
-
top_p = self.top_p
|
108 |
-
selected_model = self.model_name
|
109 |
headers = {
|
110 |
"Content-Type": "application/json",
|
111 |
"Authorization": f"Bearer {openai_api_key}",
|
@@ -115,16 +112,25 @@ class OpenAIClient(BaseLLMModel):
|
|
115 |
history = [construct_system(system_prompt), *history]
|
116 |
|
117 |
payload = {
|
118 |
-
"model":
|
119 |
-
"messages": history,
|
120 |
-
"temperature": temperature,
|
121 |
-
"top_p": top_p,
|
122 |
-
"n":
|
123 |
"stream": stream,
|
124 |
-
"presence_penalty":
|
125 |
-
"frequency_penalty":
|
126 |
-
"max_tokens": self.max_generation_token,
|
127 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
if stream:
|
129 |
timeout = TIMEOUT_STREAMING
|
130 |
else:
|
|
|
103 |
system_prompt = self.system_prompt
|
104 |
history = self.history
|
105 |
logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
|
|
|
|
|
|
|
106 |
headers = {
|
107 |
"Content-Type": "application/json",
|
108 |
"Authorization": f"Bearer {openai_api_key}",
|
|
|
112 |
history = [construct_system(system_prompt), *history]
|
113 |
|
114 |
payload = {
|
115 |
+
"model": self.model_name,
|
116 |
+
"messages": history,
|
117 |
+
"temperature": self.temperature,
|
118 |
+
"top_p": self.top_p,
|
119 |
+
"n": self.n_choices,
|
120 |
"stream": stream,
|
121 |
+
"presence_penalty": self.presence_penalty,
|
122 |
+
"frequency_penalty": self.frequency_penalty,
|
|
|
123 |
}
|
124 |
+
|
125 |
+
if self.max_generation_token is not None:
|
126 |
+
payload["max_tokens"] = self.max_generation_token
|
127 |
+
if self.stop_sequence is not None:
|
128 |
+
payload["stop"] = self.stop_sequence
|
129 |
+
if self.logit_bias is not None:
|
130 |
+
payload["logit_bias"] = self.logit_bias
|
131 |
+
if self.user_identifier is not None:
|
132 |
+
payload["user"] = self.user_identifier
|
133 |
+
|
134 |
if stream:
|
135 |
timeout = TIMEOUT_STREAMING
|
136 |
else:
|