"""A small convolutional image classifier and a factory exposing its head."""

from typing import Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

# NOTE(review): ``timm`` is not referenced anywhere in this file.  The import
# is kept for compatibility but made optional so that the module can still be
# imported in environments where ``timm`` is not installed.
try:
    import timm  # noqa: F401
except ImportError:
    timm = None


class CNNSmall(nn.Module):
    """Three-convolution feature extractor followed by a two-layer MLP head.

    Every input batch is first resized to ``INPUT_SIZE`` so the flattened
    feature vector always has the length the head expects, regardless of the
    spatial size of the images passed in.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of output logits produced by the head.

    Attributes:
        module: Convolutional feature extractor ending in a flatten layer.
        head: Fully connected classifier applied to the flattened features.
    """

    # Spatial size every input is resampled to before the convolutions.
    INPUT_SIZE = (28, 28)

    # Flattened feature length for a 28x28 input:
    #   conv5 -> 24x24, pool2 -> 12x12, conv5 -> 8x8, pool2 -> 4x4,
    #   conv2 -> 3x3 with 4 channels  =>  4 * 3 * 3 = 36.
    _FLAT_FEATURES = 4 * 3 * 3

    def __init__(self, in_channels: int = 3, num_classes: int = 10) -> None:
        super().__init__()
        self.module = nn.Sequential(
            nn.Conv2d(in_channels, 8, 5),
            nn.MaxPool2d(2, 2),
            nn.LeakyReLU(),
            nn.Conv2d(8, 6, 5),
            nn.MaxPool2d(2, 2),
            nn.LeakyReLU(),
            nn.Conv2d(6, 4, 2),
            nn.LeakyReLU(),
            nn.Flatten(start_dim=1),
        )
        self.head = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, 20),
            nn.LeakyReLU(),
            nn.Linear(20, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` to ``INPUT_SIZE`` and return the head's output.

        Args:
            x: A 4-D batch ``(N, in_channels, H, W)``; bilinear interpolation
                with a 2-element target size requires a 4-D input.

        Returns:
            A tensor of shape ``(N, num_classes)``.
        """
        # ``align_corners=False`` is what the implicit default resolves to;
        # it is spelled out so the resampling convention is visible here.
        x = F.interpolate(
            x, size=self.INPUT_SIZE, mode="bilinear", align_corners=False
        )
        features = self.module(x)
        return self.head(features)


def Model(num_classes: int = 10) -> Tuple[CNNSmall, nn.Sequential]:
    """Build a :class:`CNNSmall` and return it together with its head.

    Args:
        num_classes: Number of output logits of the classifier head.

    Returns:
        ``(model, model.head)``; the second item is the very same module
        object that the model owns, not a copy.
    """
    model = CNNSmall(num_classes=num_classes)
    return model, model.head


def _main() -> None:
    """Smoke-test the model: print output shape, layout and parameter count."""
    model, _ = Model()
    x = torch.ones([4, 3, 28, 28])
    y = model(x)
    print(y.shape)
    print(model)
    num_param = sum(p.numel() for p in model.parameters())
    print("num_param:", num_param)


if __name__ == "__main__":
    _main()