Tuchuanhuhuhu committed
Commit f2c2a56 · 1 Parent(s): c857ac1

让新增的参数们真正有用 ("Make the newly added parameters actually take effect")

Files changed (3):
  1. ChuanhuChatbot.py +10 -2
  2. modules/base_model.py +33 -7
  3. modules/models.py +17 -11
ChuanhuChatbot.py CHANGED
@@ -221,7 +221,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
             value="",
             lines=1,
         )
-        user = gr.Textbox(
+        user_identifier_txt = gr.Textbox(
             show_label=True,
             placeholder=f"用于定位滥用行为",
             label="用户名",
@@ -379,8 +379,16 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     downloadFile.change(**load_history_from_file_args)
 
     # Advanced
-    top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
     temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
+    top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
+    n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
+    stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
+    max_tokens_slider.change(current_model.value.set_max_tokens, [max_tokens_slider], None)
+    presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
+    frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
+    logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
+    user_identifier_txt.change(current_model.value.set_user_identifier, [user_identifier_txt], None)
+
    default_btn.click(
         reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
     )
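
For context, a minimal, self-contained sketch of the Gradio binding pattern these hunks extend (DummyModel is a stand-in for the repository's model class; only the wiring mirrors the commit, and Gradio 3.x is assumed). Each `.change(fn, inputs, outputs)` call forwards the component's new value to the matching setter:

    import gradio as gr  # assumption: Gradio 3.x, as this project used at the time

    class DummyModel:
        # Stand-in for the real model object held in current_model.
        def __init__(self):
            self.top_p = 1.0

        def set_top_p(self, new_top_p):
            self.top_p = new_top_p

    model = DummyModel()

    with gr.Blocks() as demo:
        top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="top-p")
        # Same calling convention as the bindings above: (fn, inputs, outputs).
        # Moving the slider invokes model.set_top_p(<new slider value>).
        top_p_slider.change(model.set_top_p, [top_p_slider], None)

    # demo.launch()  # uncomment to try it locally

One caveat worth noting: `current_model.value.set_top_p` resolves `.value` once, when the Blocks are built, so these setters stay bound to whichever model instance exists at startup.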
modules/base_model.py CHANGED
@@ -67,16 +67,12 @@ class BaseLLMModel:
         self.temperature = temperature
         self.top_p = top_p
         self.n_choices = n_choices
-        self.stop = stop
-        self.max_generation_token = (
-            max_generation_token
-            if max_generation_token is not None
-            else self.token_upper_limit
-        )
+        self.stop_sequence = stop
+        self.max_generation_token = None
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.logit_bias = logit_bias
-        self.user = user
+        self.user_identifier = user
 
     def get_answer_stream_iter(self):
         """stream predict, need to be implemented
@@ -367,6 +363,36 @@ class BaseLLMModel:
     def set_top_p(self, new_top_p):
         self.top_p = new_top_p
 
+    def set_n_choices(self, new_n_choices):
+        self.n_choices = new_n_choices
+
+    def set_stop_sequence(self, new_stop_sequence: str):
+        new_stop_sequence = new_stop_sequence.split(",")
+        self.stop_sequence = new_stop_sequence
+
+    def set_max_tokens(self, new_max_tokens):
+        self.max_generation_token = new_max_tokens
+
+    def set_presence_penalty(self, new_presence_penalty):
+        self.presence_penalty = new_presence_penalty
+
+    def set_frequency_penalty(self, new_frequency_penalty):
+        self.frequency_penalty = new_frequency_penalty
+
+    def set_logit_bias(self, logit_bias):
+        logit_bias = logit_bias.split()
+        bias_map = {}
+        encoding = tiktoken.get_encoding("cl100k_base")
+        for line in logit_bias:
+            word, bias_amount = line.split(":")
+            if word:
+                for token in encoding.encode(word):
+                    bias_map[token] = float(bias_amount)
+        self.logit_bias = bias_map
+
+    def set_user_identifier(self, new_user_identifier):
+        self.user_identifier = new_user_identifier
+
     def set_system_prompt(self, new_system_prompt):
         self.system_prompt = new_system_prompt
 
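The two string-parsing setters above deserve a concrete illustration. As I read the code (the formats are not documented in the commit itself), `set_stop_sequence` splits a comma-separated textbox value into a list, and `set_logit_bias` expects space-separated `word:bias` pairs, giving every BPE token of each word the same bias. A standalone sketch of both, requiring only the tiktoken package:

    import tiktoken

    def parse_stop_sequence(text: str) -> list:
        # "Observation:,Final Answer:" -> ["Observation:", "Final Answer:"]
        # (no whitespace stripping -- " a, b" keeps its spaces)
        return text.split(",")

    def parse_logit_bias(text: str) -> dict:
        # "hello:5 world:-100" -> {token_id: bias, ...}
        bias_map = {}
        encoding = tiktoken.get_encoding("cl100k_base")  # encoding used by gpt-3.5/gpt-4
        for item in text.split():
            word, bias_amount = item.split(":")
            if word:
                for token in encoding.encode(word):
                    bias_map[token] = float(bias_amount)
        return bias_map

    print(parse_stop_sequence("Observation:,Final Answer:"))
    print(parse_logit_bias("hello:5"))  # e.g. {15339: 5.0}; ids depend on the encoding

Per OpenAI's documentation, logit_bias values are expected in [-100, 100]; the setter does not enforce that range.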
modules/models.py CHANGED
@@ -103,9 +103,6 @@ class OpenAIClient(BaseLLMModel):
         system_prompt = self.system_prompt
         history = self.history
         logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
-        temperature = self.temperature
-        top_p = self.top_p
-        selected_model = self.model_name
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {openai_api_key}",
@@ -115,16 +112,25 @@
         history = [construct_system(system_prompt), *history]
 
         payload = {
-            "model": selected_model,
-            "messages": history,  # [{"role": "user", "content": f"{inputs}"}],
-            "temperature": temperature,  # 1.0,
-            "top_p": top_p,  # 1.0,
-            "n": 1,
+            "model": self.model_name,
+            "messages": history,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "n": self.n_choices,
             "stream": stream,
-            "presence_penalty": 0,
-            "frequency_penalty": 0,
-            "max_tokens": self.max_generation_token,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
         }
+
+        if self.max_generation_token is not None:
+            payload["max_tokens"] = self.max_generation_token
+        if self.stop_sequence is not None:
+            payload["stop"] = self.stop_sequence
+        if self.logit_bias is not None:
+            payload["logit_bias"] = self.logit_bias
+        if self.user_identifier is not None:
+            payload["user"] = self.user_identifier
+
         if stream:
             timeout = TIMEOUT_STREAMING
         else:
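
To make the effect of the new `is not None` guards concrete, a small standalone sketch with illustrative values (not from the commit): parameters the user never set stay out of the request body entirely, so the API applies its own defaults instead of receiving placeholders such as `max_tokens: null`.

    # Mirrors the guard pattern added above, with stand-in values.
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [],
        "temperature": 1.0,
        "top_p": 1.0,
        "n": 1,
        "stream": False,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }

    max_generation_token = None       # never touched in the UI
    stop_sequence = ["Observation:"]  # set via set_stop_sequence(...)

    if max_generation_token is not None:
        payload["max_tokens"] = max_generation_token
    if stop_sequence is not None:
        payload["stop"] = stop_sequence

    print(sorted(payload))  # "stop" is present, "max_tokens" is absent

Before this commit the payload hard-coded `n`, `presence_penalty`, and `frequency_penalty`, and always sent `max_tokens` (falling back to the model's token upper limit), which appears to be what the commit message refers to: the UI controls for these parameters existed but had no effect.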