qninhdt
commited on
Commit
•
95aa666
1
Parent(s):
0846051
cc
Browse files- train_autoencoder.sh → .nfs00000001516db020000048ee +0 -0
- configs/data/swim.yaml +1 -1
- configs/train.yaml +2 -0
- swim/models/autoencoder.py +6 -3
- swim/train.py +3 -0
- train_autoencoder_64.sh +6 -0
train_autoencoder.sh → .nfs00000001516db020000048ee
RENAMED
File without changes
|
configs/data/swim.yaml
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
_target_: swim.data.swim_data.SwimDataModule
|
2 |
root_dir: /home/qninh/projects/swim_/datasets/swim_data
|
3 |
-
batch_size:
|
4 |
img_size: 64
|
|
|
1 |
_target_: swim.data.swim_data.SwimDataModule
|
2 |
root_dir: /home/qninh/projects/swim_/datasets/swim_data
|
3 |
+
batch_size: 8
|
4 |
img_size: 64
|
configs/train.yaml
CHANGED
@@ -45,5 +45,7 @@ test: True
|
|
45 |
# simply provide checkpoint path to resume training
|
46 |
ckpt_path: null
|
47 |
|
|
|
|
|
48 |
# seed for random number generators in pytorch, numpy and python.random
|
49 |
seed: 42
|
|
|
45 |
# simply provide checkpoint path to resume training
|
46 |
ckpt_path: null
|
47 |
|
48 |
+
compile: false
|
49 |
+
|
50 |
# seed for random number generators in pytorch, numpy and python.random
|
51 |
seed: 42
|
swim/models/autoencoder.py
CHANGED
@@ -165,9 +165,12 @@ class Autoencoder(LightningModule):
|
|
165 |
if batch_idx == 0:
|
166 |
self.log_images(img, recon)
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
171 |
|
172 |
def log_images(self, ori_images, recon_images):
|
173 |
"""
|
|
|
165 |
if batch_idx == 0:
|
166 |
self.log_images(img, recon)
|
167 |
|
168 |
+
def compile(self):
|
169 |
+
self.encoder = torch.compile(self.encoder, "max-autotune")
|
170 |
+
self.decoder = torch.compile(self.decoder, "max-autotune")
|
171 |
+
self.quant_conv = torch.compile(self.quant_conv, "max-autotune")
|
172 |
+
self.post_quant_conv = torch.compile(self.post_quant_conv, "max-autotune")
|
173 |
+
self.lpips = torch.compile(self.lpips, "max-autotune")
|
174 |
|
175 |
def log_images(self, ori_images, recon_images):
|
176 |
"""
|
swim/train.py
CHANGED
@@ -90,6 +90,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
90 |
# compute model learning rate
|
91 |
model.learning_rate = cfg.data.batch_size * cfg.model.base_learning_rate
|
92 |
|
|
|
|
|
|
|
93 |
log.info("Starting training!")
|
94 |
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
|
95 |
|
|
|
90 |
# compute model learning rate
|
91 |
model.learning_rate = cfg.data.batch_size * cfg.model.base_learning_rate
|
92 |
|
93 |
+
if cfg.compile:
|
94 |
+
model.compile()
|
95 |
+
|
96 |
log.info("Starting training!")
|
97 |
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
|
98 |
|
train_autoencoder_64.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python swim/train.py \
|
2 |
+
data.root_dir=/cm/shared/ninhnq3/datasets/swim_data \
|
3 |
+
data.img_size=512 \
|
4 |
+
data.batch_size=16 \
|
5 |
+
compile=true \
|
6 |
+
callbacks.model_checkpoint.dirpath=/cm/shared/ninhnq3/checkpoints/swim/autoencoder/simple
|