alexkueck committed on
Commit
9808c6b
·
1 Parent(s): fc500c0

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +27 -0
utils.py CHANGED
@@ -79,6 +79,33 @@ def load_tokenizer_and_model(base_model, load_8bit=False):
79
  return tokenizer,model, device
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def load_tokenizer(base_model):
83
  if torch.cuda.is_available():
84
  device = "cuda"
 
79
  return tokenizer,model, device
80
 
81
 
82
def load_model(base_model, load_8bit=False):
    """Load a causal language model on the GPU when available, else on the CPU.

    Args:
        base_model: Model identifier or path, passed unchanged to
            ``AutoModelForCausalLM.from_pretrained``.
        load_8bit: Forwarded as ``load_in_8bit``. Only the CUDA path uses it;
            the CPU path ignores it.

    Returns:
        A ``(model, device)`` tuple: the loaded model after ``eval()`` has
        been called on it, and the device string (``"cuda"`` or ``"cpu"``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if device == "cpu":
        # CPU path: pin the whole model to the CPU via an explicit device map.
        load_kwargs = {"device_map": {"": device}, "low_cpu_mem_usage": True}
    else:
        # CUDA path: request float16 weights and automatic device placement.
        load_kwargs = {
            "load_in_8bit": load_8bit,
            "torch_dtype": torch.float16,
            "device_map": "auto",
        }

    model = AutoModelForCausalLM.from_pretrained(base_model, **load_kwargs)

    # NOTE: a manual ``model.half()`` cast (when not loading in 8-bit) used to
    # be sketched here but is disabled; no extra cast is applied after loading.
    model.eval()
    return model, device
105
+
106
+
107
+
108
+
109
  def load_tokenizer(base_model):
110
  if torch.cuda.is_available():
111
  device = "cuda"