"""Print a layer-by-layer torchinfo summary of ``SwimEncoder``.

The model and the sample input are both placed on PyTorch's ``meta`` device,
so the summary's forward pass runs on shape/dtype-only tensors with no real
activation storage.
"""

import torch
from torchinfo import summary

from swim.encoder import SwimEncoder

# Device used for every tensor in this script. ``meta`` tensors carry shape
# and dtype but no data, which is enough for a structural summary.
DEVICE = "meta"

# Shape of the dummy input passed to torch.randn below.
# NOTE(review): presumably (batch, channels, height, width) for an image
# encoder -- confirm against SwimEncoder.forward().
INPUT_SHAPE = (1, 3, 512, 512)

# NOTE(review): SwimEncoder() builds its weights on the default device before
# .to() moves them; on torch >= 2.0, constructing it inside
# ``with torch.device("meta"):`` would skip that real allocation.
encoder = SwimEncoder().to(DEVICE)

# Create the sample directly on the meta device rather than generating random
# CPU data only to discard it with .to("meta").
sample = torch.randn(*INPUT_SHAPE, device=DEVICE)

# Pass ``device`` explicitly. If it is omitted, torchinfo chooses cuda/cpu
# itself and moves the model and inputs there, which fails for meta tensors
# because they have no data to copy.
summary(encoder, input_data=(sample,), device=DEVICE)