Tuchuanhuhuhu committed
Commit 2c7dccc · 1 Parent(s): 76a432f

feat: load saved parameters from the config file

ChuanhuChatbot.py CHANGED
@@ -499,14 +499,12 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
             model_name=MODELS[DEFAULT_MODEL], access_key=my_api_key)[0]
         current_model.set_user_identifier(user_name)
         if not hide_history_when_not_logged_in or user_name:
-            filename, system_prompt, chatbot = current_model.auto_load()
+            loaded_stuff = current_model.auto_load()
         else:
-            system_prompt = gr.update()
-            filename = gr.update()
-            chatbot = gr.Chatbot.update(label=MODELS[DEFAULT_MODEL])
-        return user_info, user_name, current_model, toggle_like_btn_visibility(DEFAULT_MODEL), filename, system_prompt, chatbot, init_history_list(user_name)
+            loaded_stuff = [gr.update(), gr.update(), gr.Chatbot.update(label=MODELS[DEFAULT_MODEL]), current_model.single_turn, current_model.temperature, current_model.top_p, current_model.n_choices, current_model.stop_sequence, current_model.token_upper_limit, current_model.max_generation_token, current_model.presence_penalty, current_model.frequency_penalty, current_model.logit_bias, current_model.user_identifier]
+        return user_info, user_name, current_model, toggle_like_btn_visibility(DEFAULT_MODEL), *loaded_stuff, init_history_list(user_name)
     demo.load(create_greeting, inputs=None, outputs=[
-        user_info, user_name, current_model, like_dislike_area, saveFileName, systemPromptTxt, chatbot, historySelectList], api_name="load")
+        user_info, user_name, current_model, like_dislike_area, saveFileName, systemPromptTxt, chatbot, single_turn_checkbox, temperature_slider, top_p_slider, n_choices_slider, stop_sequence_txt, max_context_length_slider, max_generation_slider, presence_penalty_slider, frequency_penalty_slider, logit_bias_txt, user_identifier_txt, historySelectList], api_name="load")
     chatgpt_predict_args = dict(
         fn=predict,
         inputs=[
@@ -550,7 +548,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
     load_history_from_file_args = dict(
         fn=load_chat_history,
         inputs=[current_model, historySelectList],
-        outputs=[saveFileName, systemPromptTxt, chatbot]
+        outputs=[saveFileName, systemPromptTxt, chatbot, single_turn_checkbox, temperature_slider, top_p_slider, n_choices_slider, stop_sequence_txt, max_context_length_slider, max_generation_slider, presence_penalty_slider, frequency_penalty_slider, logit_bias_txt, user_identifier_txt],
     )
 
     refresh_history_args = dict(
@@ -587,7 +585,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
     emptyBtn.click(
         reset,
         inputs=[current_model, retain_system_prompt_checkbox],
-        outputs=[chatbot, status_display, historySelectList, systemPromptTxt],
+        outputs=[chatbot, status_display, historySelectList, systemPromptTxt, single_turn_checkbox, temperature_slider, top_p_slider, n_choices_slider, stop_sequence_txt, max_context_length_slider, max_generation_slider, presence_penalty_slider, frequency_penalty_slider, logit_bias_txt, user_identifier_txt],
         show_progress=True,
         _js='(a,b)=>{return clearChatbot(a,b);}',
     )
@@ -693,7 +691,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
     )
     historySelectList.input(**load_history_from_file_args)
     uploadFileBtn.upload(upload_chat_history, [current_model, uploadFileBtn], [
-        saveFileName, systemPromptTxt, chatbot]).then(**refresh_history_args)
+        saveFileName, systemPromptTxt, chatbot, single_turn_checkbox, temperature_slider, top_p_slider, n_choices_slider, stop_sequence_txt, max_context_length_slider, max_generation_slider, presence_penalty_slider, frequency_penalty_slider, logit_bias_txt, user_identifier_txt]).then(**refresh_history_args)
     historyDownloadBtn.click(None, [
         user_name, historySelectList], None, _js='(a,b)=>{return downloadHistory(a,b,".json");}')
     historyMarkdownDownloadBtn.click(None, [
@@ -725,24 +723,24 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
         cancel_all_jobs, [], [openai_train_status], show_progress=True)
 
     # Advanced
-    max_context_length_slider.change(
+    max_context_length_slider.input(
         set_token_upper_limit, [current_model, max_context_length_slider], None)
-    temperature_slider.change(
+    temperature_slider.input(
         set_temperature, [current_model, temperature_slider], None)
-    top_p_slider.change(set_top_p, [current_model, top_p_slider], None)
-    n_choices_slider.change(
+    top_p_slider.input(set_top_p, [current_model, top_p_slider], None)
+    n_choices_slider.input(
         set_n_choices, [current_model, n_choices_slider], None)
-    stop_sequence_txt.change(
+    stop_sequence_txt.input(
         set_stop_sequence, [current_model, stop_sequence_txt], None)
-    max_generation_slider.change(
+    max_generation_slider.input(
         set_max_tokens, [current_model, max_generation_slider], None)
-    presence_penalty_slider.change(
+    presence_penalty_slider.input(
         set_presence_penalty, [current_model, presence_penalty_slider], None)
-    frequency_penalty_slider.change(
+    frequency_penalty_slider.input(
         set_frequency_penalty, [current_model, frequency_penalty_slider], None)
-    logit_bias_txt.change(
+    logit_bias_txt.input(
         set_logit_bias, [current_model, logit_bias_txt], None)
-    user_identifier_txt.change(set_user_identifier, [
+    user_identifier_txt.input(set_user_identifier, [
         current_model, user_identifier_txt], None)
 
     default_btn.click(
@@ -784,7 +782,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
     historySelectBtn.click( # This is an experimental feature... Not actually used.
         fn=load_chat_history,
         inputs=[current_model, historySelectList],
-        outputs=[saveFileName, systemPromptTxt, chatbot],
+        outputs=[saveFileName, systemPromptTxt, chatbot, single_turn_checkbox, temperature_slider, top_p_slider, n_choices_slider, stop_sequence_txt, max_context_length_slider, max_generation_slider, presence_penalty_slider, frequency_penalty_slider, logit_bias_txt, user_identifier_txt],
         _js='(a,b)=>{return bgSelectHistory(a,b);}'
     )
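Two things are worth noting about this rewiring. First, `create_greeting` now splats whatever `auto_load()` returns into its own return tuple, so the length and order of `loaded_stuff` must exactly match the component list passed to `demo.load(...)` (greeting outputs before it, `historySelectList` after it). Second, the parameter controls switch from `.change` to `.input`: in Gradio, `.input` fires only on direct user interaction, while `.change` also fires when a component's value is set programmatically, so writing the loaded parameters back into the sliders no longer re-triggers the `set_*` callbacks. A minimal, self-contained sketch of the return-splat pattern (the handler and components here are hypothetical stand-ins, not the app's):

    import gradio as gr

    def on_load():
        # Stand-in for auto_load(): one value per output component.
        loaded_stuff = [gr.update(value="saved prompt"), 0.7]
        return ("hello", *loaded_stuff)  # greeting first, then unpacked values

    with gr.Blocks() as demo:
        greeting = gr.Textbox(label="greeting")
        prompt = gr.Textbox(label="system prompt")
        temperature = gr.Slider(0, 2, label="temperature")
        # The outputs list must mirror on_load's return order.
        demo.load(on_load, inputs=None, outputs=[greeting, prompt, temperature])

    demo.launch()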
modules/models/OpenAI.py CHANGED
@@ -32,6 +32,7 @@ class OpenAIClient(BaseLLMModel):
             system_prompt=system_prompt,
             user=user_name
         )
+        logging.info(f"TEMPERATURE: {self.temperature}")
         self.api_key = api_key
         self.need_api_key = True
         self._refresh_header()
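The only functional line added here is a temperature log at client construction, presumably to verify that parameters restored from a history file actually reach the model. The diff adds no import, so `logging` must already be imported in OpenAI.py, and the message only appears if the logger is configured at INFO level or lower. A minimal configuration under which the line prints:

    import logging

    logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")

    temperature = 0.7  # stand-in for self.temperature
    logging.info(f"TEMPERATURE: {temperature}")  # -> INFO - TEMPERATURE: 0.7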
modules/models/base_model.py CHANGED
@@ -214,6 +214,7 @@ class BaseLLMModel:
         frequency_penalty=0,
         logit_bias=None,
         user="",
+        single_turn=False,
     ) -> None:
         self.history = []
         self.all_token_counts = []
@@ -230,10 +231,21 @@ class BaseLLMModel:
         self.system_prompt = system_prompt
         self.api_key = None
         self.need_api_key = False
-        self.single_turn = False
         self.history_file_path = get_first_history_name(user)
         self.user_name = user
 
+        self.default_single_turn = single_turn
+        self.default_temperature = temperature
+        self.default_top_p = top_p
+        self.default_n_choices = n_choices
+        self.default_stop_sequence = stop
+        self.default_max_generation_token = max_generation_token
+        self.default_presence_penalty = presence_penalty
+        self.default_frequency_penalty = frequency_penalty
+        self.default_logit_bias = logit_bias
+        self.default_user_identifier = user
+
+        self.single_turn = single_turn
         self.temperature = temperature
         self.top_p = top_p
         self.n_choices = n_choices
@@ -251,7 +263,9 @@ class BaseLLMModel:
         Conversations are stored in self.history, with the most recent question in OpenAI format.
         Should return a generator that yields the next word (str) in the answer.
         """
-        logging.warning("Stream prediction is not implemented. Using at once prediction instead.")
+        logging.warning(
+            "Stream prediction is not implemented. Using at once prediction instead."
+        )
         response, _ = self.get_answer_at_once()
         yield response
 
@@ -749,11 +763,34 @@ class BaseLLMModel:
         history_name = self.history_file_path[:-5]
         choices = [history_name] + get_history_names(self.user_name)
         system_prompt = self.system_prompt if remain_system_prompt else ""
+
+        self.single_turn = self.default_single_turn
+        self.temperature = self.default_temperature
+        self.top_p = self.default_top_p
+        self.n_choices = self.default_n_choices
+        self.stop_sequence = self.default_stop_sequence
+        self.max_generation_token = self.default_max_generation_token
+        self.presence_penalty = self.default_presence_penalty
+        self.frequency_penalty = self.default_frequency_penalty
+        self.logit_bias = self.default_logit_bias
+        self.user_identifier = self.default_user_identifier
+
         return (
             [],
             self.token_message([0]),
             gr.Radio.update(choices=choices, value=history_name),
             system_prompt,
+            self.single_turn,
+            self.temperature,
+            self.top_p,
+            self.n_choices,
+            self.stop_sequence,
+            self.token_upper_limit,
+            self.max_generation_token,
+            self.presence_penalty,
+            self.frequency_penalty,
+            self.logit_bias,
+            self.user_identifier,
         )
 
     def delete_first_conversation(self):
@@ -877,30 +914,67 @@ class BaseLLMModel:
                 pass
             if len(saved_json["chatbot"]) < len(saved_json["history"]) // 2:
                 logging.info("Trimming corrupted history...")
-                saved_json["history"] = saved_json["history"][-len(saved_json["chatbot"]) :]
+                saved_json["history"] = saved_json["history"][
+                    -len(saved_json["chatbot"]) :
+                ]
                 logging.info(f"Trimmed history: {saved_json['history']}")
             logging.debug(f"{self.user_name} 加载对话历史完毕")
             self.history = saved_json["history"]
-            self.single_turn = saved_json.get("single_turn", False)
-            self.temperature = saved_json.get("temperature", 1.0)
-            self.top_p = saved_json.get("top_p", None)
-            self.n_choices = saved_json.get("n_choices", 1)
-            self.stop_sequence = saved_json.get("stop_sequence", None)
-            self.max_generation_token = saved_json.get("max_generation_token", None)
-            self.presence_penalty = saved_json.get("presence_penalty", 0)
-            self.frequency_penalty = saved_json.get("frequency_penalty", 0)
-            self.logit_bias = saved_json.get("logit_bias", None)
+            self.single_turn = saved_json.get("single_turn", self.single_turn)
+            self.temperature = saved_json.get("temperature", self.temperature)
+            self.top_p = saved_json.get("top_p", self.top_p)
+            self.n_choices = saved_json.get("n_choices", self.n_choices)
+            self.stop_sequence = saved_json.get("stop_sequence", self.stop_sequence)
+            self.token_upper_limit = saved_json.get(
+                "token_upper_limit", self.token_upper_limit
+            )
+            self.max_generation_token = saved_json.get(
+                "max_generation_token", self.max_generation_token
+            )
+            self.presence_penalty = saved_json.get(
+                "presence_penalty", self.presence_penalty
+            )
+            self.frequency_penalty = saved_json.get(
+                "frequency_penalty", self.frequency_penalty
+            )
+            self.logit_bias = saved_json.get("logit_bias", self.logit_bias)
             self.user_identifier = saved_json.get("user_identifier", self.user_name)
-            self.metadata = saved_json.get("metadata", {})
+            self.metadata = saved_json.get("metadata", self.metadata)
             return (
-                os.path.basename(self.history_file_path),
+                os.path.basename(self.history_file_path)[:-5],
                 saved_json["system"],
                 saved_json["chatbot"],
+                self.single_turn,
+                self.temperature,
+                self.top_p,
+                self.n_choices,
+                self.stop_sequence,
+                self.token_upper_limit,
+                self.max_generation_token,
+                self.presence_penalty,
+                self.frequency_penalty,
+                self.logit_bias,
+                self.user_identifier,
             )
         except:
             # 没有对话历史或者对话历史解析失败
             logging.info(f"没有找到对话历史记录 {self.history_file_path}")
-            return self.history_file_path, "", []
+            return (
+                os.path.basename(self.history_file_path),
+                "",
+                [],
+                self.single_turn,
+                self.temperature,
+                self.top_p,
+                self.n_choices,
+                self.stop_sequence,
+                self.token_upper_limit,
+                self.max_generation_token,
+                self.presence_penalty,
+                self.frequency_penalty,
+                self.logit_bias,
+                self.user_identifier,
            )
 
     def delete_chat_history(self, filename):
         if filename == "CANCELED":
@@ -910,9 +984,7 @@ class BaseLLMModel:
         if not filename.endswith(".json"):
             filename += ".json"
         if filename == os.path.basename(filename):
-            history_file_path = os.path.join(
-                HISTORY_DIR, self.user_name, filename
-            )
+            history_file_path = os.path.join(HISTORY_DIR, self.user_name, filename)
         else:
             history_file_path = filename
         md_history_file_path = history_file_path[:-5] + ".md"
@@ -934,9 +1006,7 @@ class BaseLLMModel:
             self.history_file_path = new_auto_history_filename(self.user_name)
         else:
             self.history_file_path = filepath
-        filename, system_prompt, chatbot = self.load_chat_history()
-        filename = filename[:-5]
-        return filename, system_prompt, chatbot
+        return self.load_chat_history()
 
     def like(self):
         """like the last response, implement if needed"""
modules/utils.py CHANGED
@@ -389,6 +389,7 @@ def save_file(filename, model, chatbot):
         "top_p": model.top_p,
         "n_choices": model.n_choices,
         "stop_sequence": model.stop_sequence,
+        "token_upper_limit": model.token_upper_limit,
         "max_generation_token": model.max_generation_token,
         "presence_penalty": model.presence_penalty,
         "frequency_penalty": model.frequency_penalty,