danyloylo commited on
Commit
65f6692
1 Parent(s): b06793d

Delete train_laion_face.py

Browse files
Files changed (1) hide show
  1. train_laion_face.py +0 -46
train_laion_face.py DELETED
@@ -1,46 +0,0 @@
1
- from share import *
2
-
3
- import pytorch_lightning as pl
4
- from torch.utils.data import DataLoader
5
- from laion_face_dataset import LaionDataset
6
- from cldm.logger import ImageLogger
7
- from cldm.model import create_model, load_state_dict
8
-
9
-
10
- # Configs
11
- resume_path = './models/controlnet_sd21_laion_face.ckpt'
12
- batch_size = 4
13
- logger_freq = 2500
14
- learning_rate = 1e-5
15
- sd_locked = True
16
- only_mid_control = False
17
-
18
-
19
- # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
20
- model = create_model('./models/cldm_v21.yaml').cpu()
21
- model.load_state_dict(load_state_dict(resume_path, location='cpu'))
22
- model.learning_rate = learning_rate
23
- model.sd_locked = sd_locked
24
- model.only_mid_control = only_mid_control
25
-
26
-
27
- # Save every so often:
28
- ckpt_callback = pl.callbacks.ModelCheckpoint(
29
- dirpath="./checkpoints/",
30
- filename="ckpt_controlnet_sd21_{epoch}_{step}_{loss}",
31
- monitor='train/loss_simple_step',
32
- save_top_k=5,
33
- every_n_train_steps=5000,
34
- save_last=True,
35
- )
36
-
37
-
38
- # Misc
39
- dataset = LaionDataset()
40
- dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
41
- logger = ImageLogger(batch_frequency=logger_freq)
42
- trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger, ckpt_callback])
43
-
44
-
45
- # Train!
46
- trainer.fit(model, dataloader)