train model
Browse files- scripts/train_model.py +3 -2
scripts/train_model.py
CHANGED
@@ -221,7 +221,8 @@ config.torch_dtype = torch.bfloat16
|
|
221 |
print(config)
|
222 |
|
223 |
model = AutoModelForCausalLM.from_config(config)
|
224 |
-
model = model.to(torch.bfloat16)
|
|
|
225 |
print(model)
|
226 |
|
227 |
training_args = TrainingArguments(
|
@@ -238,7 +239,7 @@ training_args = TrainingArguments(
|
|
238 |
logging_steps=10,
|
239 |
fp16=False,
|
240 |
bf16=True,
|
241 |
-
|
242 |
)
|
243 |
print(training_args)
|
244 |
|
|
|
221 |
print(config)
|
222 |
|
223 |
model = AutoModelForCausalLM.from_config(config)
|
224 |
+
# model = model.to(torch.bfloat16)
|
225 |
+
model = torch.compile(model)
|
226 |
print(model)
|
227 |
|
228 |
training_args = TrainingArguments(
|
|
|
239 |
logging_steps=10,
|
240 |
fp16=False,
|
241 |
bf16=True,
|
242 |
+
torch_compile=True,
|
243 |
)
|
244 |
print(training_args)
|
245 |
|