File size: 762 Bytes
6ef7ab3
 
 
 
 
 
 
 
020afa7
 
6ef7ab3
 
 
 
 
020afa7
 
 
6ef7ab3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from omegaconf import OmegaConf
from swim.utils import instantiate_from_config
from torchinfo import summary
from swim.modules.dataset import SwimDataModule
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger

# Default run settings. Override them by calling main(...) with keyword arguments.
CONFIG_PATH = "configs/autoencoder/autoencoder_kl_32x32x4.yaml"
DATA_ROOT = "/cm/shared/ninhnq3/datasets/swim_data"


def main(
    *,
    config_path: str = CONFIG_PATH,
    data_root: str = DATA_ROOT,
    batch_size: int = 2,
    img_size: int = 512,
    max_epochs: int = 10,
    devices=(0,),
    log_every_n_steps: int = 10,
) -> None:
    """Train the model described by an OmegaConf config on SwimDataModule data.

    The model is built with ``instantiate_from_config(config.model)``.
    Metrics are logged to Weights & Biases (project ``"swim"``, run
    ``"autoencoder_kl"``).

    Args:
        config_path: Path to the YAML config. It must contain a ``model``
            section with a ``base_learning_rate`` entry.
        data_root: Root directory passed to ``SwimDataModule``.
        batch_size: Batch size passed to ``SwimDataModule``.
        img_size: Image size passed to ``SwimDataModule``.
        max_epochs: Number of training epochs for the Lightning ``Trainer``.
        devices: Device indices to train on; converted to a list for the
            ``Trainer``.
        log_every_n_steps: Logging frequency for the ``Trainer``.
    """
    # Allow reduced-precision internal float32 matmuls for speed.
    torch.set_float32_matmul_precision("medium")

    config = OmegaConf.load(config_path)

    model = instantiate_from_config(config.model)
    # The config's base learning rate is used as-is; no batch-size or
    # device-count scaling is applied here.
    model.learning_rate = config.model.base_learning_rate

    datamodule = SwimDataModule(
        root_dir=data_root, batch_size=batch_size, img_size=img_size
    )

    logger = WandbLogger(project="swim", name="autoencoder_kl")

    trainer = Trainer(
        max_epochs=max_epochs,
        devices=list(devices),
        logger=logger,
        log_every_n_steps=log_every_n_steps,
    )
    # Pass the datamodule by keyword instead of through the positional
    # train_dataloaders slot.
    trainer.fit(model, datamodule=datamodule)


# Only train when run as a script. Without this guard, importing the module,
# or a spawn-started worker process re-importing __main__, would start another
# full training run.
if __name__ == "__main__":
    main()