MZhaovo committed
Commit 90ce14b · 1 Parent(s): 0a2a419

Define the local LLMs as global variables so they are loaded only once instead of on every call. make class great again

Files changed (4)
  1. ChuanhuChatbot.py +38 -32
  2. modules/models.py +138 -222
  3. modules/presets.py +5 -0
  4. modules/utils.py +76 -0
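
The core of the commit is the caching pattern visible in modules/models.py and modules/presets.py below: the heavyweight local models (ChatGLM, LLaMA) are no longer held on the client instance but in module-level globals, so re-creating a client reuses weights that are already in memory. A minimal sketch of that lazy-caching idea, where load_weights is an illustrative placeholder standing in for the real AutoModel/AutoTokenizer loading:

# Minimal sketch of the module-level caching introduced in this commit (names illustrative).
import logging

CHATGLM_MODEL = None        # in the real code these globals live in modules/presets.py
CHATGLM_TOKENIZER = None

def load_weights(model_name):
    # Placeholder for the expensive from_pretrained() calls.
    logging.info("Loading weights for %s ...", model_name)
    return object(), object()

class ChatGLMClientSketch:
    def __init__(self, model_name):
        global CHATGLM_MODEL, CHATGLM_TOKENIZER
        # Load only once; later instances reuse the cached globals.
        if CHATGLM_MODEL is None or CHATGLM_TOKENIZER is None:
            CHATGLM_MODEL, CHATGLM_TOKENIZER = load_weights(model_name)
        self.model_name = model_name

a = ChatGLMClientSketch("chatglm-6b-int4")   # triggers the load
b = ChatGLMClientSketch("chatglm-6b-int4")   # reuses the cached weights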
ChuanhuChatbot.py CHANGED
@@ -10,7 +10,7 @@ from modules.config import *
 from modules.utils import *
 from modules.presets import *
 from modules.overwrites import *
-from modules.models import ModelManager
+from modules.models import get_model

 gr.Chatbot.postprocess = postprocess
 PromptHelper.compact_text_chunks = compact_text_chunks
@@ -18,11 +18,14 @@ PromptHelper.compact_text_chunks = compact_text_chunks
 with open("assets/custom.css", "r", encoding="utf-8") as f:
     customCSS = f.read()

+def create_new_model():
+    return get_model(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)[0]
+
 with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     user_name = gr.State("")
     promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
     user_question = gr.State("")
-    current_model = gr.State(ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key))
+    current_model = gr.State(create_new_model)

     topic = gr.State("未命名对话历史记录")

@@ -264,8 +267,9 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     gr.Markdown(CHUANHU_DESCRIPTION)
     gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
     chatgpt_predict_args = dict(
-        fn=current_model.value.predict,
+        fn=predict,
         inputs=[
+            current_model,
             user_question,
             chatbot,
             use_streaming_checkbox,
@@ -297,18 +301,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     )

     get_usage_args = dict(
-        fn=current_model.value.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
+        fn=billing_info, inputs=[current_model], outputs=[usageTxt], show_progress=False
     )

     load_history_from_file_args = dict(
-        fn=current_model.value.load_chat_history,
-        inputs=[historyFileSelectDropdown, chatbot, user_name],
+        fn=load_chat_history,
+        inputs=[current_model, historyFileSelectDropdown, chatbot, user_name],
         outputs=[saveFileName, systemPromptTxt, chatbot]
     )


     # Chatbot
-    cancelBtn.click(current_model.value.interrupt, [], [])
+    cancelBtn.click(interrupt, [current_model], [])

     user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
     user_input.submit(**get_usage_args)
@@ -317,15 +321,17 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     submitBtn.click(**get_usage_args)

     emptyBtn.click(
-        current_model.value.reset,
+        reset,
+        inputs=[current_model],
         outputs=[chatbot, status_display],
         show_progress=True,
     )
     emptyBtn.click(**reset_textbox_args)

     retryBtn.click(**start_outputing_args).then(
-        current_model.value.retry,
+        retry,
         [
+            current_model,
             chatbot,
             use_streaming_checkbox,
             use_websearch_checkbox,
@@ -338,14 +344,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     retryBtn.click(**get_usage_args)

     delFirstBtn.click(
-        current_model.value.delete_first_conversation,
-        None,
+        delete_first_conversation,
+        [current_model],
         [status_display],
     )

     delLastBtn.click(
-        current_model.value.delete_last_conversation,
-        [chatbot],
+        delete_last_conversation,
+        [current_model, chatbot],
         [chatbot, status_display],
         show_progress=False
     )
@@ -353,14 +359,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     two_column.change(update_doc_config, [two_column], None)

     # LLM Models
-    keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
+    keyTxt.change(set_key, [current_model, keyTxt], [status_display]).then(**get_usage_args)
     keyTxt.submit(**get_usage_args)
-    single_turn_checkbox.change(current_model.value.set_single_turn, single_turn_checkbox, None)
-    model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
-    lora_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
+    single_turn_checkbox.change(set_single_turn, [current_model, single_turn_checkbox], None)
+    model_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display, lora_select_dropdown], show_progress=True)
+    lora_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display], show_progress=True)

     # Template
-    systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
+    systemPromptTxt.change(set_system_prompt, [current_model, systemPromptTxt], None)
     templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
     templateFileSelectDropdown.change(
         load_template,
@@ -377,15 +383,15 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:

     # S&L
     saveHistoryBtn.click(
-        current_model.value.save_chat_history,
-        [saveFileName, chatbot, user_name],
+        save_chat_history,
+        [current_model, saveFileName, chatbot, user_name],
         downloadFile,
         show_progress=True,
     )
     saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
     exportMarkdownBtn.click(
-        current_model.value.export_markdown,
-        [saveFileName, chatbot, user_name],
+        export_markdown,
+        [current_model, saveFileName, chatbot, user_name],
         downloadFile,
         show_progress=True,
     )
@@ -394,16 +400,16 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     downloadFile.change(**load_history_from_file_args)

     # Advanced
-    max_context_length_slider.change(current_model.value.set_token_upper_limit, [max_context_length_slider], None)
-    temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
-    top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
-    n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
-    stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
-    max_generation_slider.change(current_model.value.set_max_tokens, [max_generation_slider], None)
-    presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
-    frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
-    logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
-    user_identifier_txt.change(current_model.value.set_user_identifier, [user_identifier_txt], None)
+    max_context_length_slider.change(set_token_upper_limit, [current_model, max_context_length_slider], None)
+    temperature_slider.change(set_temperature, [current_model, temperature_slider], None)
+    top_p_slider.change(set_top_p, [current_model, top_p_slider], None)
+    n_choices_slider.change(set_n_choices, [current_model, n_choices_slider], None)
+    stop_sequence_txt.change(set_stop_sequence, [current_model, stop_sequence_txt], None)
+    max_generation_slider.change(set_max_tokens, [current_model, max_generation_slider], None)
+    presence_penalty_slider.change(set_presence_penalty, [current_model, presence_penalty_slider], None)
+    frequency_penalty_slider.change(set_frequency_penalty, [current_model, frequency_penalty_slider], None)
+    logit_bias_txt.change(set_logit_bias, [current_model, logit_bias_txt], None)
+    user_identifier_txt.change(set_user_identifier, [current_model, user_identifier_txt], None)

     default_btn.click(
         reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
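
With ModelManager gone, ChuanhuChatbot.py keeps the active model in a gr.State and passes it as the first input to plain module-level functions such as predict, billing_info and interrupt (added to modules/utils.py below). A minimal, self-contained sketch of that wiring, using a toy EchoModel rather than the repository's real clients:

# Sketch of the gr.State + module-level wrapper pattern (EchoModel is illustrative).
import gradio as gr

class EchoModel:
    def predict(self, text):
        return f"echo: {text}"

def predict(current_model, text):
    # The State value arrives as an ordinary positional argument.
    return current_model.predict(text)

with gr.Blocks() as demo:
    current_model = gr.State(EchoModel())   # one model object per session
    inp = gr.Textbox(label="input")
    out = gr.Textbox(label="output")
    send = gr.Button("send")
    send.click(predict, inputs=[current_model, inp], outputs=[out])

if __name__ == "__main__":
    demo.launch()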
modules/models.py CHANGED
@@ -207,51 +207,52 @@ class OpenAIClient(BaseLLMModel):
             continue
         if error_msg:
             raise Exception(error_msg)
-
+


 class ChatGLM_Client(BaseLLMModel):
     def __init__(self, model_name) -> None:
         super().__init__(model_name=model_name)
         from transformers import AutoTokenizer, AutoModel
         import torch
-
-        system_name = platform.system()
-        model_path=None
-        if os.path.exists("models"):
-            model_dirs = os.listdir("models")
-            if model_name in model_dirs:
-                model_path = f"models/{model_name}"
-        if model_path is not None:
-            model_source = model_path
-        else:
-            model_source = f"THUDM/{model_name}"
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_source, trust_remote_code=True
-        )
-        quantified = False
-        if "int4" in model_name:
-            quantified = True
-        if quantified:
-            model = AutoModel.from_pretrained(
-                model_source, trust_remote_code=True
-            ).half()
-        else:
-            model = AutoModel.from_pretrained(
+        global CHATGLM_TOKENIZER, CHATGLM_MODEL
+        if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
+            system_name = platform.system()
+            model_path = None
+            if os.path.exists("models"):
+                model_dirs = os.listdir("models")
+                if model_name in model_dirs:
+                    model_path = f"models/{model_name}"
+            if model_path is not None:
+                model_source = model_path
+            else:
+                model_source = f"THUDM/{model_name}"
+            CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
                 model_source, trust_remote_code=True
-            ).half()
-        if torch.cuda.is_available():
-            # run on CUDA
-            logging.info("CUDA is available, using CUDA")
-            model = model.cuda()
-        # mps加速还存在一些问题,暂时不使用
-        elif system_name == "Darwin" and model_path is not None and not quantified:
-            logging.info("Running on macOS, using MPS")
-            # running on macOS and model already downloaded
-            model = model.to("mps")
-        else:
-            logging.info("GPU is not available, using CPU")
-        model = model.eval()
-        self.model = model
+            )
+            quantified = False
+            if "int4" in model_name:
+                quantified = True
+            if quantified:
+                model = AutoModel.from_pretrained(
+                    model_source, trust_remote_code=True
+                ).half()
+            else:
+                model = AutoModel.from_pretrained(
+                    model_source, trust_remote_code=True
+                ).half()
+            if torch.cuda.is_available():
+                # run on CUDA
+                logging.info("CUDA is available, using CUDA")
+                model = model.cuda()
+            # mps加速还存在一些问题,暂时不使用
+            elif system_name == "Darwin" and model_path is not None and not quantified:
+                logging.info("Running on macOS, using MPS")
+                # running on macOS and model already downloaded
+                model = model.to("mps")
+            else:
+                logging.info("GPU is not available, using CPU")
+            model = model.eval()
+            CHATGLM_MODEL = model

     def _get_glm_style_input(self):
         history = [x["content"] for x in self.history]
@@ -265,13 +266,13 @@ class ChatGLM_Client(BaseLLMModel):

     def get_answer_at_once(self):
         history, query = self._get_glm_style_input()
-        response, _ = self.model.chat(self.tokenizer, query, history=history)
+        response, _ = CHATGLM_MODEL.chat(CHATGLM_TOKENIZER, query, history=history)
         return response, len(response)

     def get_answer_stream_iter(self):
         history, query = self._get_glm_style_input()
-        for response, history in self.model.stream_chat(
-            self.tokenizer,
+        for response, history in CHATGLM_MODEL.stream_chat(
+            CHATGLM_TOKENIZER,
             query,
             history,
             max_length=self.token_upper_limit,
@@ -292,53 +293,53 @@ class LLaMA_Client(BaseLLMModel):
         from lmflow.pipeline.auto_pipeline import AutoPipeline
         from lmflow.models.auto_model import AutoModel
         from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
-        model_path = None
-        if os.path.exists("models"):
-            model_dirs = os.listdir("models")
-            if model_name in model_dirs:
-                model_path = f"models/{model_name}"
-        if model_path is not None:
-            model_source = model_path
-        else:
-            model_source = f"decapoda-research/{model_name}"
-            # raise Exception(f"models目录下没有这个模型: {model_name}")
-        if lora_path is not None:
-            lora_path = f"lora/{lora_path}"
+
         self.max_generation_token = 1000
-        pipeline_name = "inferencer"
-        model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
-        pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
-
-        with open(pipeline_args.deepspeed, "r") as f:
-            ds_config = json.load(f)
-
-        self.model = AutoModel.get_model(
-            model_args,
-            tune_strategy="none",
-            ds_config=ds_config,
-        )
-
+        self.end_string = "\n\n"
         # We don't need input data
         data_args = DatasetArguments(dataset_path=None)
         self.dataset = Dataset(data_args)

-        self.inferencer = AutoPipeline.get_pipeline(
-            pipeline_name=pipeline_name,
-            model_args=model_args,
-            data_args=data_args,
-            pipeline_args=pipeline_args,
-        )
-
+        global LLAMA_MODEL, LLAMA_INFERENCER
+        if LLAMA_MODEL is None or LLAMA_INFERENCER is None:
+            model_path = None
+            if os.path.exists("models"):
+                model_dirs = os.listdir("models")
+                if model_name in model_dirs:
+                    model_path = f"models/{model_name}"
+            if model_path is not None:
+                model_source = model_path
+            else:
+                model_source = f"decapoda-research/{model_name}"
+                # raise Exception(f"models目录下没有这个模型: {model_name}")
+            if lora_path is not None:
+                lora_path = f"lora/{lora_path}"
+            model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
+            pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
+
+            with open(pipeline_args.deepspeed, "r") as f:
+                ds_config = json.load(f)
+            LLAMA_MODEL = AutoModel.get_model(
+                model_args,
+                tune_strategy="none",
+                ds_config=ds_config,
+            )
+            LLAMA_INFERENCER = AutoPipeline.get_pipeline(
+                pipeline_name="inferencer",
+                model_args=model_args,
+                data_args=data_args,
+                pipeline_args=pipeline_args,
+            )
         # Chats
-        model_name = model_args.model_name_or_path
-        if model_args.lora_model_path is not None:
-            model_name += f" + {model_args.lora_model_path}"
+        # model_name = model_args.model_name_or_path
+        # if model_args.lora_model_path is not None:
+        #     model_name += f" + {model_args.lora_model_path}"

         # context = (
         #     "You are a helpful assistant who follows the given instructions"
         #     " unconditionally."
         # )
-        self.end_string = "\n\n"
+

     def _get_llama_style_input(self):
         history = []
@@ -358,8 +359,8 @@ class LLaMA_Client(BaseLLMModel):
             {"type": "text_only", "instances": [{"text": context}]}
         )

-        output_dataset = self.inferencer.inference(
-            model=self.model,
+        output_dataset = LLAMA_INFERENCER.inference(
+            model=LLAMA_MODEL,
             dataset=input_dataset,
             max_new_tokens=self.max_generation_token,
             temperature=self.temperature,
@@ -376,8 +377,8 @@ class LLaMA_Client(BaseLLMModel):
         input_dataset = self.dataset.from_dict(
             {"type": "text_only", "instances": [{"text": context+partial_text}]}
         )
-        output_dataset = self.inferencer.inference(
-            model=self.model,
+        output_dataset = LLAMA_INFERENCER.inference(
+            model=LLAMA_MODEL,
             dataset=input_dataset,
             max_new_tokens=step,
             temperature=self.temperature,
@@ -389,147 +390,62 @@ class LLaMA_Client(BaseLLMModel):
         yield partial_text


-class ModelManager:
-    def __init__(self, **kwargs) -> None:
-        self.model = None
-        self.get_model(**kwargs)
-
-    def get_model(
-        self,
-        model_name,
-        lora_model_path=None,
-        access_key=None,
-        temperature=None,
-        top_p=None,
-        system_prompt=None,
-    ) -> BaseLLMModel:
-        msg = f"模型设置为了: {model_name}"
-        model_type = ModelType.get_type(model_name)
-        lora_selector_visibility = False
-        lora_choices = []
-        dont_change_lora_selector = False
-        if model_type != ModelType.OpenAI:
-            config.local_embedding = True
-        del self.model
-        model = None
-        try:
-            if model_type == ModelType.OpenAI:
-                logging.info(f"正在加载OpenAI模型: {model_name}")
-                model = OpenAIClient(
-                    model_name=model_name,
-                    api_key=access_key,
-                    system_prompt=system_prompt,
-                    temperature=temperature,
-                    top_p=top_p,
-                )
-            elif model_type == ModelType.ChatGLM:
-                logging.info(f"正在加载ChatGLM模型: {model_name}")
-                model = ChatGLM_Client(model_name)
-            elif model_type == ModelType.LLaMA and lora_model_path == "":
-                msg = f"现在请为 {model_name} 选择LoRA模型"
-                logging.info(msg)
-                lora_selector_visibility = True
-                if os.path.isdir("lora"):
-                    lora_choices = get_file_names("lora", plain=True, filetypes=[""])
-                lora_choices = ["No LoRA"] + lora_choices
-            elif model_type == ModelType.LLaMA and lora_model_path != "":
-                logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
-                dont_change_lora_selector = True
-                if lora_model_path == "No LoRA":
-                    lora_model_path = None
-                    msg += " + No LoRA"
-                else:
-                    msg += f" + {lora_model_path}"
-                model = LLaMA_Client(model_name, lora_model_path)
-            elif model_type == ModelType.Unknown:
-                raise ValueError(f"未知模型: {model_name}")
+def get_model(
+    model_name,
+    lora_model_path=None,
+    access_key=None,
+    temperature=None,
+    top_p=None,
+    system_prompt=None,
+) -> BaseLLMModel:
+    msg = f"模型设置为了: {model_name}"
+    model_type = ModelType.get_type(model_name)
+    lora_selector_visibility = False
+    lora_choices = []
+    dont_change_lora_selector = False
+    if model_type != ModelType.OpenAI:
+        config.local_embedding = True
+    # del current_model.model
+    model = None
+    try:
+        if model_type == ModelType.OpenAI:
+            logging.info(f"正在加载OpenAI模型: {model_name}")
+            model = OpenAIClient(
+                model_name=model_name,
+                api_key=access_key,
+                system_prompt=system_prompt,
+                temperature=temperature,
+                top_p=top_p,
+            )
+        elif model_type == ModelType.ChatGLM:
+            logging.info(f"正在加载ChatGLM模型: {model_name}")
+            model = ChatGLM_Client(model_name)
+        elif model_type == ModelType.LLaMA and lora_model_path == "":
+            msg = f"现在请为 {model_name} 选择LoRA模型"
             logging.info(msg)
-        except Exception as e:
-            logging.error(e)
-            msg = f"{STANDARD_ERROR_MSG}: {e}"
-        self.model = model
-        if dont_change_lora_selector:
-            return msg
-        else:
-            return msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
-
-    def predict(self, *args):
-        iter = self.model.predict(*args)
-        for i in iter:
-            yield i
-
-    def billing_info(self):
-        return self.model.billing_info()
-
-    def set_key(self, *args):
-        return self.model.set_key(*args)
-
-    def load_chat_history(self, *args):
-        return self.model.load_chat_history(*args)
-
-    def interrupt(self, *args):
-        return self.model.interrupt(*args)
-
-    def reset(self, *args):
-        return self.model.reset(*args)
-
-    def retry(self, *args):
-        iter = self.model.retry(*args)
-        for i in iter:
-            yield i
-
-    def delete_first_conversation(self, *args):
-        return self.model.delete_first_conversation(*args)
-
-    def delete_last_conversation(self, *args):
-        return self.model.delete_last_conversation(*args)
-
-    def set_system_prompt(self, *args):
-        return self.model.set_system_prompt(*args)
-
-    def save_chat_history(self, *args):
-        return self.model.save_chat_history(*args)
-
-    def export_markdown(self, *args):
-        return self.model.export_markdown(*args)
-
-    def load_chat_history(self, *args):
-        return self.model.load_chat_history(*args)
-
-    def set_token_upper_limit(self, *args):
-        return self.model.set_token_upper_limit(*args)
-
-    def set_temperature(self, *args):
-        self.model.set_temperature(*args)
-
-    def set_top_p(self, *args):
-        self.model.set_top_p(*args)
-
-    def set_n_choices(self, *args):
-        self.model.set_n_choices(*args)
-
-    def set_stop_sequence(self, *args):
-        self.model.set_stop_sequence(*args)
-
-    def set_max_tokens(self, *args):
-        self.model.set_max_tokens(*args)
-
-    def set_presence_penalty(self, *args):
-        self.model.set_presence_penalty(*args)
-
-    def set_frequency_penalty(self, *args):
-        self.model.set_frequency_penalty(*args)
-
-    def set_logit_bias(self, *args):
-        self.model.set_logit_bias(*args)
-
-    def set_user_identifier(self, *args):
-        self.model.set_user_identifier(*args)
-
-    def set_single_turn(self, *args):
-        self.model.set_single_turn(*args)
-
-
+            lora_selector_visibility = True
+            if os.path.isdir("lora"):
+                lora_choices = get_file_names("lora", plain=True, filetypes=[""])
+            lora_choices = ["No LoRA"] + lora_choices
+        elif model_type == ModelType.LLaMA and lora_model_path != "":
+            logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
+            dont_change_lora_selector = True
+            if lora_model_path == "No LoRA":
+                lora_model_path = None
+                msg += " + No LoRA"
+            else:
+                msg += f" + {lora_model_path}"
+            model = LLaMA_Client(model_name, lora_model_path)
+        elif model_type == ModelType.Unknown:
+            raise ValueError(f"未知模型: {model_name}")
+        logging.info(msg)
+    except Exception as e:
+        logging.error(e)
+        msg = f"{STANDARD_ERROR_MSG}: {e}"
+    if dont_change_lora_selector:
+        return model, msg
+    else:
+        return model, msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)


 if __name__ == "__main__":
@@ -538,7 +454,7 @@ if __name__ == "__main__":
     # set logging level to debug
     logging.basicConfig(level=logging.DEBUG)
     # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
-    client = ModelManager(model_name="chatglm-6b-int4")
+    client = get_model(model_name="chatglm-6b-int4")
     chatbot = []
     stream = False
     # 测试账单功能
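
Because get_model is now a free function rather than a method, it returns the newly built client as its first value, and the dropdown callbacks above list current_model first in their outputs so gradio writes that object back into the session state. A hedged sketch of this output-side pattern (ToyModel and switch_model are illustrative, not the repository's code):

# Sketch: a change handler returns (new_model, status); listing the gr.State in
# outputs makes gradio store the first returned value as the new state.
import gradio as gr

class ToyModel:
    def __init__(self, name):
        self.name = name

def switch_model(model_name):
    model = ToyModel(model_name)
    return model, f"Model set to: {model_name}"

with gr.Blocks() as demo:
    current_model = gr.State(ToyModel("default"))
    model_dropdown = gr.Dropdown(choices=["model-a", "model-b"], label="model")
    status = gr.Markdown()
    model_dropdown.change(switch_model, [model_dropdown], [current_model, status])

if __name__ == "__main__":
    demo.launch()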
modules/presets.py CHANGED
@@ -4,6 +4,11 @@ from pathlib import Path

 import gradio as gr

+CHATGLM_MODEL = None
+CHATGLM_TOKENIZER = None
+LLAMA_MODEL = None
+LLAMA_INFERENCER = None
+
 # ChatGPT 设置
 INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
 API_HOST = "api.openai.com"
modules/utils.py CHANGED
@@ -33,6 +33,82 @@ if TYPE_CHECKING:
     class DataframeData(TypedDict):
         headers: List[str]
         data: List[List[str | int | bool]]
+
+def predict(current_model, *args):
+    iter = current_model.predict(*args)
+    for i in iter:
+        yield i
+
+def billing_info(current_model):
+    return current_model.billing_info()
+
+def set_key(current_model, *args):
+    return current_model.set_key(*args)
+
+def load_chat_history(current_model, *args):
+    return current_model.load_chat_history(*args)
+
+def interrupt(current_model, *args):
+    return current_model.interrupt(*args)
+
+def reset(current_model, *args):
+    return current_model.reset(*args)
+
+def retry(current_model, *args):
+    iter = current_model.retry(*args)
+    for i in iter:
+        yield i
+
+def delete_first_conversation(current_model, *args):
+    return current_model.delete_first_conversation(*args)
+
+def delete_last_conversation(current_model, *args):
+    return current_model.delete_last_conversation(*args)
+
+def set_system_prompt(current_model, *args):
+    return current_model.set_system_prompt(*args)
+
+def save_chat_history(current_model, *args):
+    return current_model.save_chat_history(*args)
+
+def export_markdown(current_model, *args):
+    return current_model.export_markdown(*args)
+
+def load_chat_history(current_model, *args):
+    return current_model.load_chat_history(*args)
+
+def set_token_upper_limit(current_model, *args):
+    return current_model.set_token_upper_limit(*args)
+
+def set_temperature(current_model, *args):
+    current_model.set_temperature(*args)
+
+def set_top_p(current_model, *args):
+    current_model.set_top_p(*args)
+
+def set_n_choices(current_model, *args):
+    current_model.set_n_choices(*args)
+
+def set_stop_sequence(current_model, *args):
+    current_model.set_stop_sequence(*args)
+
+def set_max_tokens(current_model, *args):
+    current_model.set_max_tokens(*args)
+
+def set_presence_penalty(current_model, *args):
+    current_model.set_presence_penalty(*args)
+
+def set_frequency_penalty(current_model, *args):
+    current_model.set_frequency_penalty(*args)
+
+def set_logit_bias(current_model, *args):
+    current_model.set_logit_bias(*args)
+
+def set_user_identifier(current_model, *args):
+    current_model.set_user_identifier(*args)
+
+def set_single_turn(current_model, *args):
+    current_model.set_single_turn(*args)


 def count_token(message):