Tuchuanhuhuhu committed
Commit 31c7630 · 1 Parent(s): 180ee81

feat: Qwen supports loading local/custom models

Files changed (1)
  1. modules/models/Qwen.py +13 -2
modules/models/Qwen.py CHANGED
@@ -1,4 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import os
 from transformers.generation import GenerationConfig
 import logging
 import colorama
@@ -9,8 +10,18 @@ from ..presets import MODEL_METADATA
 class Qwen_Client(BaseLLMModel):
     def __init__(self, model_name, user_name="") -> None:
         super().__init__(model_name=model_name, user=user_name)
-        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_METADATA[model_name]["repo_id"], trust_remote_code=True, resume_download=True)
-        self.model = AutoModelForCausalLM.from_pretrained(MODEL_METADATA[model_name]["repo_id"], device_map="auto", trust_remote_code=True, resume_download=True).eval()
+        model_source = None
+        if os.path.exists("models"):
+            model_dirs = os.listdir("models")
+            if model_name in model_dirs:
+                model_source = f"models/{model_name}"
+        if model_source is None:
+            try:
+                model_source = MODEL_METADATA[model_name]["repo_id"]
+            except KeyError:
+                model_source = model_name
+        self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True, resume_download=True)
+        self.model = AutoModelForCausalLM.from_pretrained(model_source, device_map="auto", trust_remote_code=True, resume_download=True).eval()

     def generation_config(self):
         return GenerationConfig.from_dict({
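
For reference, a minimal standalone sketch of the lookup order this commit introduces: a model name is first resolved to a matching folder under models/, then to the repo_id registered in MODEL_METADATA, and finally treated as a raw Hugging Face repo id. The script below is not part of the commit; the MODEL_METADATA stand-in dictionary and the example model names are illustrative assumptions.

import os

# Stand-in for modules/presets MODEL_METADATA; keys and repo ids here are illustrative only.
MODEL_METADATA = {"Qwen 7B Chat": {"repo_id": "Qwen/Qwen-7B-Chat"}}

def resolve_model_source(model_name: str) -> str:
    """Mirror of the resolution logic added to Qwen_Client.__init__ in this commit."""
    model_source = None
    # 1) Prefer a local entry under models/ whose name matches the requested model.
    if os.path.exists("models"):
        model_dirs = os.listdir("models")
        if model_name in model_dirs:
            model_source = f"models/{model_name}"
    if model_source is None:
        # 2) Fall back to the repo_id registered in MODEL_METADATA;
        # 3) otherwise treat the name itself as a Hugging Face repo id.
        try:
            model_source = MODEL_METADATA[model_name]["repo_id"]
        except KeyError:
            model_source = model_name
    return model_source

# Hypothetical calls (directory and model names are assumptions):
print(resolve_model_source("Qwen-7B-Chat"))   # -> "models/Qwen-7B-Chat" if that folder exists
print(resolve_model_source("Qwen 7B Chat"))   # -> "Qwen/Qwen-7B-Chat" via MODEL_METADATA
print(resolve_model_source("Qwen/Qwen-14B"))  # -> "Qwen/Qwen-14B", passed through as a repo id

Whatever string is returned is handed directly to AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained, which accept either a local directory or a Hugging Face repo id.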