Delete train_laion_face.py
Browse files- train_laion_face.py +0 -46
train_laion_face.py
DELETED
@@ -1,46 +0,0 @@
|
|
1 |
-
from share import *
|
2 |
-
|
3 |
-
import pytorch_lightning as pl
|
4 |
-
from torch.utils.data import DataLoader
|
5 |
-
from laion_face_dataset import LaionDataset
|
6 |
-
from cldm.logger import ImageLogger
|
7 |
-
from cldm.model import create_model, load_state_dict
|
8 |
-
|
9 |
-
|
10 |
-
# Configs
|
11 |
-
resume_path = './models/controlnet_sd21_laion_face.ckpt'
|
12 |
-
batch_size = 4
|
13 |
-
logger_freq = 2500
|
14 |
-
learning_rate = 1e-5
|
15 |
-
sd_locked = True
|
16 |
-
only_mid_control = False
|
17 |
-
|
18 |
-
|
19 |
-
# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
|
20 |
-
model = create_model('./models/cldm_v21.yaml').cpu()
|
21 |
-
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
|
22 |
-
model.learning_rate = learning_rate
|
23 |
-
model.sd_locked = sd_locked
|
24 |
-
model.only_mid_control = only_mid_control
|
25 |
-
|
26 |
-
|
27 |
-
# Save every so often:
|
28 |
-
ckpt_callback = pl.callbacks.ModelCheckpoint(
|
29 |
-
dirpath="./checkpoints/",
|
30 |
-
filename="ckpt_controlnet_sd21_{epoch}_{step}_{loss}",
|
31 |
-
monitor='train/loss_simple_step',
|
32 |
-
save_top_k=5,
|
33 |
-
every_n_train_steps=5000,
|
34 |
-
save_last=True,
|
35 |
-
)
|
36 |
-
|
37 |
-
|
38 |
-
# Misc
|
39 |
-
dataset = LaionDataset()
|
40 |
-
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
|
41 |
-
logger = ImageLogger(batch_frequency=logger_freq)
|
42 |
-
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger, ckpt_callback])
|
43 |
-
|
44 |
-
|
45 |
-
# Train!
|
46 |
-
trainer.fit(model, dataloader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|