|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class ConvBNBlock(nn.Module):
    """3x3 convolution -> batch norm -> channel dropout -> ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the 3x3 convolution (padding is 1).
        p: drop probability for ``nn.Dropout2d``; 0.0 disables dropout.
    """

    def __init__(self, in_planes, planes, stride=1, p=0.0):
        super().__init__()
        self.dropout_prob = p
        conv = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,  # the following BatchNorm provides a learnable shift
        )
        self.conv_bn_block = nn.Sequential(conv, nn.BatchNorm2d(planes))
        self.drop_out = nn.Dropout2d(p=p)

    def forward(self, x):
        features = self.conv_bn_block(x)
        features = self.drop_out(features)
        return F.relu(features)
|
|
|
class TransitionBlock(nn.Module):
    """Downsampling stage: 3x3 conv -> batch norm -> ReLU -> 2x2 max-pool -> dropout.

    The 2x2/stride-2 max-pool halves the spatial size (flooring odd sizes),
    in addition to any reduction from the convolution's ``stride``.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the 3x3 convolution (padding is 1).
        p: drop probability for the trailing ``nn.Dropout2d``.
    """

    def __init__(self, in_planes, planes, stride=1, p=0.0):
        super().__init__()
        self.p = p
        stages = [
            nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=p),
        ]
        self.transition_block = nn.Sequential(*stages)

    def forward(self, x):
        return self.transition_block(x)
|
|
|
class ResBlock(nn.Module):
    """Transition (downsampling) stage followed by a two-conv residual branch.

    Computes ``t = transition_block(x)`` and returns
    ``t + conv_block2(conv_block1(t))``.

    Args:
        in_planes: number of input channels.
        planes: number of output channels (also used by the residual branch).
        stride: stride of the transition convolution only. The residual
            convolutions always run at stride 1 so the branch output has the
            same spatial size as its input and can be added to it.
        p: dropout probability passed to every sub-block.
    """

    def __init__(self, in_planes, planes, stride=1, p=0.0):
        super().__init__()
        self.p = p
        self.transition_block = TransitionBlock(in_planes, planes, stride, p)
        # Stride 1 on purpose: downsampling the residual branch would make
        # ``x + r`` either fail on a shape mismatch or silently broadcast a
        # smaller tensor. Identical to the previous behaviour when stride == 1.
        self.conv_block1 = ConvBNBlock(planes, planes, 1, p)
        self.conv_block2 = ConvBNBlock(planes, planes, 1, p)

    def forward(self, x):
        x = self.transition_block(x)
        r = self.conv_block2(self.conv_block1(x))
        return x + r
|
|
|
class CustomResNet(nn.Module):
    """ResNet-style image classifier returning log-probabilities.

    Pipeline: prep ConvBNBlock (3 -> 64) -> ResBlock (64 -> 128) ->
    TransitionBlock (128 -> 256) -> ResBlock (256 -> 512) -> 4x4 max-pool ->
    flatten -> linear -> log-softmax over classes.

    Each of the three stages contains one 2x2 max-pool and the prep conv keeps
    the spatial size, so a 32x32 input shrinks 32 -> 16 -> 8 -> 4 -> 1 and the
    flattened features match the ``nn.Linear(512, num_classes)`` input. Inputs
    whose final map is not reduced to 1x1 by the 4x4 pool will fail at the
    linear layer.

    Args:
        p: dropout probability forwarded to every sub-block.
        num_classes: number of output classes.
    """

    def __init__(self, p=0.0, num_classes=10):
        super().__init__()
        self.in_planes = 64  # not read inside this class; kept for compatibility
        self.p = p

        self.conv = ConvBNBlock(3, 64, 1, p)
        self.layer1 = ResBlock(64, 128, 1, p)
        self.layer2 = TransitionBlock(128, 256, 1, p)
        self.layer3 = ResBlock(256, 512, 1, p)
        self.max_pool = nn.MaxPool2d(4, 4)
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        feature_stages = (self.conv, self.layer1, self.layer2, self.layer3, self.max_pool)
        out = x
        for stage in feature_stages:
            out = stage(out)
        batch = out.size(0)
        logits = self.linear(out.view(batch, -1))
        return F.log_softmax(logits, dim=1)
|
|