qninhdt committed on
Commit
95aa666
1 Parent(s): 0846051
train_autoencoder.sh → .nfs00000001516db020000048ee RENAMED
File without changes
configs/data/swim.yaml CHANGED
@@ -1,4 +1,4 @@
1
  _target_: swim.data.swim_data.SwimDataModule
2
  root_dir: /home/qninh/projects/swim_/datasets/swim_data
3
- batch_size: 4
4
  img_size: 64
 
1
  _target_: swim.data.swim_data.SwimDataModule
2
  root_dir: /home/qninh/projects/swim_/datasets/swim_data
3
+ batch_size: 8
4
  img_size: 64
configs/train.yaml CHANGED
@@ -45,5 +45,7 @@ test: True
45
  # simply provide checkpoint path to resume training
46
  ckpt_path: null
47
 
 
 
48
  # seed for random number generators in pytorch, numpy and python.random
49
  seed: 42
 
45
  # simply provide checkpoint path to resume training
46
  ckpt_path: null
47
 
48
+ compile: false
49
+
50
  # seed for random number generators in pytorch, numpy and python.random
51
  seed: 42
swim/models/autoencoder.py CHANGED
@@ -165,9 +165,12 @@ class Autoencoder(LightningModule):
165
  if batch_idx == 0:
166
  self.log_images(img, recon)
167
 
168
- # def setup(self, stage: str) -> None:
169
- # if self.hparams.compile and stage == "fit":
170
- # self.net = torch.compile(self.net)
 
 
 
171
 
172
  def log_images(self, ori_images, recon_images):
173
  """
 
165
  if batch_idx == 0:
166
  self.log_images(img, recon)
167
 
168
def compile(self):
    """Swap the heavy submodules for torch.compile-optimized versions.

    Rebinds ``encoder``, ``decoder``, ``quant_conv``, ``post_quant_conv``
    and ``lpips`` to wrapped (``OptimizedModule``) versions of themselves.
    Compilation is lazy: the actual kernel tuning happens on first forward.

    BUGFIX: ``torch.compile``'s parameters after ``model`` are keyword-only,
    so the original positional ``torch.compile(x, "max-autotune")`` raised
    ``TypeError``; the mode must be passed as ``mode="max-autotune"``.

    NOTE(review): this overrides ``nn.Module.compile`` (torch >= 2.1), which
    compiles in place — here attributes are rebound instead; confirm that is
    intended for checkpointing/state_dict key names.
    """
    # "max-autotune" trades longer compile time for the fastest kernels.
    self.encoder = torch.compile(self.encoder, mode="max-autotune")
    self.decoder = torch.compile(self.decoder, mode="max-autotune")
    self.quant_conv = torch.compile(self.quant_conv, mode="max-autotune")
    self.post_quant_conv = torch.compile(self.post_quant_conv, mode="max-autotune")
    self.lpips = torch.compile(self.lpips, mode="max-autotune")
174
 
175
  def log_images(self, ori_images, recon_images):
176
  """
swim/train.py CHANGED
@@ -90,6 +90,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
90
  # compute model learning rate
91
  model.learning_rate = cfg.data.batch_size * cfg.model.base_learning_rate
92
 
 
 
 
93
  log.info("Starting training!")
94
  trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
95
 
 
90
  # compute model learning rate
91
  model.learning_rate = cfg.data.batch_size * cfg.model.base_learning_rate
92
 
93
# Optionally wrap the model's submodules with torch.compile before fitting;
# controlled by the top-level `compile` flag added in configs/train.yaml
# (defaults to false, overridable from the CLI, e.g. `compile=true`).
if cfg.compile:
    model.compile()
96
  log.info("Starting training!")
97
  trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
98
 
train_autoencoder_64.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ python swim/train.py \
2
+ data.root_dir=/cm/shared/ninhnq3/datasets/swim_data \
3
+ data.img_size=512 \
4
+ data.batch_size=16 \
5
+ compile=true \
6
+ callbacks.model_checkpoint.dirpath=/cm/shared/ninhnq3/checkpoints/swim/autoencoder/simple