# swim/main.py
# Author: qninhdt — last commit 6ef7ab3 ("cc")
import torch
from omegaconf import OmegaConf
from swim.utils import instantiate_from_config
from torchinfo import summary
from swim.modules.dataset import SwimDataModule
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
def main(
    config_path="configs/autoencoder/autoencoder_kl_32x32x4.yaml",
    img_size=32,
    max_epochs=10,
    devices=None,
):
    """Train the autoencoder described by *config_path* on the Swim dataset.

    Every default reproduces the original hard-coded run, so executing this
    file as a script behaves exactly as before.

    Args:
        config_path: Path to the OmegaConf YAML file. Its ``model`` section is
            passed to ``instantiate_from_config`` and must contain
            ``base_learning_rate``.
        img_size: Image size forwarded to ``SwimDataModule``.
        max_epochs: Maximum number of training epochs for the ``Trainer``.
        devices: Device indices passed to the Lightning ``Trainer``;
            ``None`` means ``[0]``.
    """
    if devices is None:
        # Fresh list per call rather than a shared mutable default argument.
        devices = [0]

    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model)
    # The base rate from the config is assigned unscaled (no batch-size or
    # device-count scaling). Presumably read by the model's optimizer setup —
    # confirm in the model class.
    model.learning_rate = config.model.base_learning_rate

    # NOTE(review): the config name "32x32x4" may describe the latent size
    # rather than the input size; confirm 32-pixel input images are intended.
    datamodule = SwimDataModule(img_size=img_size)

    wandb_logger = WandbLogger(project="swim", name="autoencoder_kl")
    trainer = Trainer(
        max_epochs=max_epochs,
        devices=devices,
        logger=wandb_logger,
        log_every_n_steps=10,
    )
    # Pass the datamodule by keyword instead of via the positional
    # ``train_dataloaders`` slot, so the intent is explicit.
    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    main()