# swim / train.py
# Source page residue: author qninhdt, revision 9b66f69 ("cc"), 204 bytes.
import torch
from torchinfo import summary
from swim.encoder import SwimEncoder
def main(image_size: int = 512, batch_size: int = 1) -> None:
    """Print a torchinfo layer summary of ``SwimEncoder`` on the meta device.

    The encoder and a dummy 3-channel input are placed on PyTorch's ``meta``
    device, so the forward pass propagates shapes only and no real input
    storage is allocated. Note that ``SwimEncoder()`` is constructed with
    default arguments before being moved, so its parameters are presumably
    initialized on the default device first (TODO confirm for large models).

    Args:
        image_size: Height and width of the square dummy input.
        batch_size: Batch dimension of the dummy input.
    """
    encoder = SwimEncoder().to("meta")
    # Build the input directly on the meta device rather than allocating a
    # real CPU tensor with random data and then moving it.
    sample = torch.randn(batch_size, 3, image_size, image_size, device="meta")
    # torchinfo moves the model and inputs to `device` before the forward
    # pass; when `device` is omitted it chooses CUDA/CPU, and meta tensors
    # cannot be copied there. Pin it to the device the tensors already use.
    summary(encoder, input_data=(sample,), device="meta")


if __name__ == "__main__":
    main()