alexkueck commited on
Commit
31c7c6f
·
1 Parent(s): c856a9e

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +3 -1
utils.py CHANGED
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
22
  import datasets
23
  from datasets import load_dataset
24
  import evaluate
 
25
 
26
 
27
 
@@ -91,7 +92,8 @@ def load_tokenizer_and_model_Blaize(base_model, load_8bit=True):
91
 
92
  tokenizer = LlamaTokenizer.from_pretrained(base_model, add_eos_token=True, use_auth_token=True)
93
  model = LlamaForCausalLM.from_pretrained(base_model, load_in_8bit=True, device_map="auto")
94
- #model.eval()
 
95
  return tokenizer,model, device
96
 
97
 
 
22
  import datasets
23
  from datasets import load_dataset
24
  import evaluate
25
+ from transformers import LlamaForCausalLM, LlamaTokenizer
26
 
27
 
28
 
 
92
 
93
  tokenizer = LlamaTokenizer.from_pretrained(base_model, add_eos_token=True, use_auth_token=True)
94
  model = LlamaForCausalLM.from_pretrained(base_model, load_in_8bit=True, device_map="auto")
95
+ model = prepare_model_for_int8_training(model)
96
+
97
  return tokenizer,model, device
98
 
99