train model
Browse files- scripts/train_model.py +1 -1
scripts/train_model.py
CHANGED
@@ -224,7 +224,7 @@ config.torch_dtype = torch.bfloat16
|
|
224 |
print(config)
|
225 |
|
226 |
model = AutoModelForCausalLM.from_config(config)
|
227 |
-
|
228 |
model = torch.compile(model)
|
229 |
model = model.to(device)
|
230 |
print(model)
|
|
|
224 |
print(config)
|
225 |
|
226 |
model = AutoModelForCausalLM.from_config(config)
|
227 |
+
model = model.to(torch.bfloat16)
|
228 |
model = torch.compile(model)
|
229 |
model = model.to(device)
|
230 |
print(model)
|