Spaces:
Runtime error
Runtime error
File size: 2,275 Bytes
cb80c28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
'''
ConvNet backbones for ImageNet-sized inputs, used by the MEMO implementation.
Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py
'''
import torch.nn as nn
import torch
# for imagenet
def first_block(in_channels, out_channels):
    """Build the stem stage: 7x7 stride-2 conv, batch norm, ReLU, 2x max-pool.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` whose children are, in order, the convolution,
        batch norm, ReLU and max-pool layers.
    """
    stem_layers = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*stem_layers)
def conv_block(in_channels, out_channels):
    """Build one downsampling stage: 3x3 conv, batch norm, ReLU, 2x max-pool.

    The convolution uses padding 1, so the spatial size is preserved until
    the max-pool halves it.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` of (conv, batch norm, ReLU, max-pool).
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU()
    downsample = nn.MaxPool2d(2)
    return nn.Sequential(conv, norm, activation, downsample)
class ConvNet(nn.Module):
    """Four-stage ConvNet feature extractor.

    Layout: ``first_block`` (7x7 stride-2 conv + 2x max-pool), three
    ``conv_block`` stages (3x3 conv + 2x max-pool each), then a 7x7 average
    pool. The stages shrink the spatial size by roughly a factor of 32
    (e.g. 224 -> 112 -> 56 -> 28 -> 14 -> 7), so for 224x224 inputs the
    final 7x7 pool collapses the map to 1x1. Other input sizes produce a
    different flattened length — NOTE(review): presumably 224x224 ImageNet
    crops are the intended input; confirm with callers.

    Args:
        x_dim: Number of input channels.
        hid_dim: Channel width of the first three stages.
        z_dim: Channel width of the last stage, i.e. the feature size.
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)
        self.block4 = conv_block(hid_dim, z_dim)
        self.avgpool = nn.AvgPool2d(7)
        # The feature width is the channel count of block4, so derive it
        # from z_dim instead of hard-coding 512 (wrong for any other z_dim).
        self.out_dim = z_dim

    def forward(self, x):
        """Extract pooled features.

        Args:
            x: Input image batch, assumed shaped (batch, x_dim, H, W).

        Returns:
            ``{"features": tensor}`` where the tensor is the pooled map
            flattened to (batch, -1); its width equals ``out_dim`` when the
            pre-pool map is 7x7.
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.avgpool(x)
        features = torch.flatten(x, 1)
        return {
            "features": features
        }
class GeneralizedConvNet(nn.Module):
    """Lower portion of the split ConvNet: the stem plus two conv stages.

    ``SpecializedConvNet`` holds the remaining stage and the pooling. The
    output here is the un-pooled feature map with ``hid_dim`` channels.

    Args:
        x_dim: Number of input channels.
        hid_dim: Channel width of all three stages.
        z_dim: Unused; kept so the signature mirrors ``ConvNet``.
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        # Construction order matters for parameter-init RNG; keep 1 -> 2 -> 3.
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)

    def forward(self, x):
        """Apply block1, block2 and block3 in turn and return the feature map."""
        for stage in (self.block1, self.block2, self.block3):
            x = stage(x)
        return x
class SpecializedConvNet(nn.Module):
    """Upper portion of the split ConvNet: the last conv stage plus pooling.

    ``conv_a2fc_imagenet`` pairs it with ``GeneralizedConvNet``, whose
    ``hid_dim``-channel output is the expected input here.

    Args:
        hid_dim: Channels of the incoming feature map.
        z_dim: Output channels of the conv stage, i.e. the feature size.
    """

    def __init__(self, hid_dim=128, z_dim=512):
        super().__init__()
        self.block4 = conv_block(hid_dim, z_dim)
        # 7x7 window: collapses a 7x7 map (what 224x224 inputs yield after
        # the full network) to 1x1 — NOTE(review): confirm input size.
        self.avgpool = nn.AvgPool2d(7)
        # Derived from z_dim rather than hard-coding 512, so it matches the
        # actual channel count of the returned features for any z_dim.
        self.feature_dim = z_dim

    def forward(self, x):
        """Apply block4 and average pooling, then flatten to (batch, -1)."""
        x = self.block4(x)
        x = self.avgpool(x)
        features = torch.flatten(x, 1)
        return features
def conv4():
    """Return a ``ConvNet`` constructed with all default arguments."""
    return ConvNet()
def conv_a2fc_imagenet():
    """Build the split backbone as a ``(base, adaptive)`` pair.

    Returns:
        A tuple ``(GeneralizedConvNet(), SpecializedConvNet())``.
        - The first element runs the stem and the next two stages.
        - The second runs the final stage and average pooling.

        Both use default widths, so the 128-channel output of the first
        matches the expected input of the second.
    """
    return GeneralizedConvNet(), SpecializedConvNet()