mtasic85 commited on
Commit
ace007f
1 Parent(s): 6d8ae94

train model

Browse files
Files changed (1) hide show
  1. scripts/train_model.py +3 -2
scripts/train_model.py CHANGED
@@ -221,7 +221,8 @@ config.torch_dtype = torch.bfloat16
221
  print(config)
222
 
223
  model = AutoModelForCausalLM.from_config(config)
224
- model = model.to(torch.bfloat16)
 
225
  print(model)
226
 
227
  training_args = TrainingArguments(
@@ -238,7 +239,7 @@ training_args = TrainingArguments(
238
  logging_steps=10,
239
  fp16=False,
240
  bf16=True,
241
- # torch_compile=True,
242
  )
243
  print(training_args)
244
 
 
221
  print(config)
222
 
223
  model = AutoModelForCausalLM.from_config(config)
224
+ # model = model.to(torch.bfloat16)
225
+ model = torch.compile(model)
226
  print(model)
227
 
228
  training_args = TrainingArguments(
 
239
  logging_steps=10,
240
  fp16=False,
241
  bf16=True,
242
+ torch_compile=True,
243
  )
244
  print(training_args)
245