johnsmith253325 committed
Commit b346648 · 1 Parent(s): 7d0f396

feat: 加入LoRA功能 (add LoRA support)

Files changed (2)
  1. modules/models/LLaMA.py +49 -35
  2. modules/models/models.py +2 -3
modules/models/LLaMA.py CHANGED
@@ -11,10 +11,6 @@ from ..presets import *
 from ..utils import *
 from .base_model import BaseLLMModel
 
-import json
-from llama_cpp import Llama
-from huggingface_hub import hf_hub_download
-
 SYS_PREFIX = "<<SYS>>\n"
 SYS_POSTFIX = "\n<</SYS>>\n\n"
 INST_PREFIX = "<s>[INST] "
@@ -22,6 +18,7 @@ INST_POSTFIX = " "
 OUTPUT_PREFIX = "[/INST] "
 OUTPUT_POSTFIX = "</s>"
 
+
 def download(repo_id, filename, retry=10):
     if os.path.exists("./models/downloaded_models.json"):
         with open("./models/downloaded_models.json", "r") as f:
@@ -32,7 +29,12 @@ def download(repo_id, filename, retry=10):
         downloaded_models = {}
     while retry > 0:
         try:
-            model_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir="models", resume_download=True)
+            model_path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                cache_dir="models",
+                resume_download=True,
+            )
             downloaded_models[repo_id] = {"path": model_path}
             with open("./models/downloaded_models.json", "w") as f:
                 json.dump(downloaded_models, f)
@@ -46,57 +48,69 @@ def download(repo_id, filename, retry=10):
 
 
 class LLaMA_Client(BaseLLMModel):
-    def __init__(
-        self,
-        model_name,
-        lora_path=None,
-        user_name=""
-    ) -> None:
+    def __init__(self, model_name, lora_path=None, user_name="") -> None:
         super().__init__(model_name=model_name, user=user_name)
 
         self.max_generation_token = 1000
-        self.end_string = "\n\n"
-        # We don't need input data
-        path_to_model = download(MODEL_METADATA[model_name]["repo_id"], MODEL_METADATA[model_name]["filelist"][0])
+        if model_name in MODEL_METADATA:
+            path_to_model = download(
+                MODEL_METADATA[model_name]["repo_id"],
+                MODEL_METADATA[model_name]["filelist"][0],
+            )
+        else:
+            dir_to_model = os.path.join("models", model_name)
+            # look for any .gguf file in the dir_to_model directory and its subdirectories
+            path_to_model = None
+            for root, dirs, files in os.walk(dir_to_model):
+                for file in files:
+                    if file.endswith(".gguf"):
+                        path_to_model = os.path.join(root, file)
+                        break
+                if path_to_model is not None:
+                    break
         self.system_prompt = ""
 
-        global LLAMA_MODEL
-        if LLAMA_MODEL is None:
-            LLAMA_MODEL = Llama(model_path=path_to_model)
-        # model_path = None
-        # if os.path.exists("models"):
-        #     model_dirs = os.listdir("models")
-        #     if model_name in model_dirs:
-        #         model_path = f"models/{model_name}"
-        #     if model_path is not None:
-        #         model_source = model_path
-        #     else:
-        #         model_source = f"decapoda-research/{model_name}"
-        #         raise Exception(f"models目录下没有这个模型: {model_name}")
-        # if lora_path is not None:
-        #     lora_path = f"lora/{lora_path}"
+        if lora_path is not None:
+            lora_path = os.path.join("lora", lora_path)
+            self.model = Llama(model_path=path_to_model, lora_path=lora_path)
+        else:
+            self.model = Llama(model_path=path_to_model)
 
     def _get_llama_style_input(self):
         context = []
         for conv in self.history:
             if conv["role"] == "system":
-                context.append(SYS_PREFIX+conv["content"]+SYS_POSTFIX)
+                context.append(SYS_PREFIX + conv["content"] + SYS_POSTFIX)
             elif conv["role"] == "user":
-                context.append(INST_PREFIX+conv["content"]+INST_POSTFIX+OUTPUT_PREFIX)
+                context.append(
+                    INST_PREFIX + conv["content"] + INST_POSTFIX + OUTPUT_PREFIX
+                )
             else:
-                context.append(conv["content"]+OUTPUT_POSTFIX)
+                context.append(conv["content"] + OUTPUT_POSTFIX)
         return "".join(context)
 
     def get_answer_at_once(self):
         context = self._get_llama_style_input()
-        response = LLAMA_MODEL(context, max_tokens=self.max_generation_token, stop=[], echo=False, stream=False)
+        response = self.model(
+            context,
+            max_tokens=self.max_generation_token,
+            stop=[],
+            echo=False,
+            stream=False,
+        )
        return response, len(response)
 
     def get_answer_stream_iter(self):
         context = self._get_llama_style_input()
-        iter = LLAMA_MODEL(context, max_tokens=self.max_generation_token, stop=[], echo=False, stream=True)
+        iter = self.model(
+            context,
+            max_tokens=self.max_generation_token,
+            stop=[],
+            echo=False,
+            stream=True,
+        )
         partial_text = ""
         for i in iter:
             response = i["choices"][0]["text"]
             partial_text += response
-            yield partial_text
+            yield partial_text
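The net effect of this file's changes: LLaMA_Client drops the global LLAMA_MODEL, accepts models not listed in MODEL_METADATA by scanning models/<model_name> for a .gguf file, and passes an optional adapter from lora/ straight to llama_cpp.Llama. Below is a minimal standalone sketch of that loading path, using only the calls that appear in the diff above; the model and adapter file names are placeholders, not files shipped with the repo.

```python
import os
from llama_cpp import Llama


def find_gguf(model_dir):
    # Same search the new __init__ performs: first .gguf found under the directory wins.
    for root, _dirs, files in os.walk(model_dir):
        for file in files:
            if file.endswith(".gguf"):
                return os.path.join(root, file)
    return None


# Placeholder names for illustration only.
path_to_model = find_gguf(os.path.join("models", "llama-2-7b-chat"))
lora_path = os.path.join("lora", "my-adapter.bin")  # use None to skip LoRA

if lora_path is not None:
    model = Llama(model_path=path_to_model, lora_path=lora_path)
else:
    model = Llama(model_path=path_to_model)

# Prompt assembled the way _get_llama_style_input does for a single user turn:
# INST_PREFIX + content + INST_POSTFIX + OUTPUT_PREFIX
prompt = "<s>[INST] Hello! [/INST] "
for chunk in model(prompt, max_tokens=64, stop=[], echo=False, stream=True):
    print(chunk["choices"][0]["text"], end="", flush=True)
```
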
modules/models/models.py CHANGED
@@ -26,7 +26,7 @@ def get_model(
     msg = i18n("模型设置为了:") + f" {model_name}"
     model_type = ModelType.get_type(model_name)
     lora_selector_visibility = False
-    lora_choices = []
+    lora_choices = ["No LoRA"]
     dont_change_lora_selector = False
     if model_type != ModelType.OpenAI:
         config.local_embedding = True
@@ -55,8 +55,7 @@ def get_model(
             logging.info(msg)
             lora_selector_visibility = True
             if os.path.isdir("lora"):
-                get_file_names_by_pinyin("lora", filetypes=[""])
-            lora_choices = ["No LoRA"] + lora_choices
+                lora_choices = ["No LoRA"] + get_file_names_by_pinyin("lora", filetypes=[""])
         elif model_type == ModelType.LLaMA and lora_model_path != "":
             logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
             from .LLaMA import LLaMA_Client
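For context, the second hunk fixes a dropped return value: previously the result of get_file_names_by_pinyin was discarded, so lora_choices kept its initial value and the LoRA dropdown never listed the files under lora/. A small sketch of the corrected behaviour follows; the get_file_names_by_pinyin stand-in is only an approximation of the repo's helper (which orders names by pinyin), included so the snippet runs on its own.

```python
import os


def get_file_names_by_pinyin(directory, filetypes):
    # Stand-in for the repo helper: list entries, filter by extension,
    # plain sort instead of pinyin ordering.
    names = [
        f for f in os.listdir(directory)
        if filetypes == [""] or os.path.splitext(f)[1] in filetypes
    ]
    return sorted(names)


lora_choices = ["No LoRA"]  # new default, so the selector is never empty
if os.path.isdir("lora"):
    # Old code called the helper without assigning its result; the fix prepends
    # "No LoRA" to the actual directory listing.
    lora_choices = ["No LoRA"] + get_file_names_by_pinyin("lora", filetypes=[""])

print(lora_choices)
```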