import torch import torch.nn as nn import torch.nn.functional as F class ConvBNBlock(nn.Module): def __init__(self, in_planes, planes, stride=1, p=0.0): super(ConvBNBlock, self).__init__() self.dropout_prob = p self.conv_bn_block = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(planes) ) self.drop_out = nn.Dropout2d(p=self.dropout_prob) def forward(self, x): out =F.relu(self.drop_out(self.conv_bn_block(x)) ) return out class TransitionBlock(nn.Module): def __init__(self, in_planes, planes, stride=1, p=0.0): super(TransitionBlock, self).__init__() self.p = p self.transition_block = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(planes), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), nn.Dropout2d(p=self.p) ) def forward(self, x): x = self.transition_block(x) return x class ResBlock(nn.Module): def __init__(self, in_planes, planes, stride=1, p=0.0): super(ResBlock, self).__init__() self.p = p self.transition_block = TransitionBlock(in_planes, planes, stride, p) self.conv_block1 = ConvBNBlock(planes, planes, stride, p) self.conv_block2 = ConvBNBlock(planes, planes, stride, p) def forward(self, x): x = self.transition_block(x) r = self.conv_block2(self.conv_block1(x)) out = x + r return out class CustomResNet(nn.Module): def __init__(self, p=0.0, num_classes=10): super(CustomResNet, self).__init__() self.in_planes = 64 self.p = p self.conv = ConvBNBlock(3, 64, 1, p) self.layer1 = ResBlock(64, 128, 1, p) self.layer2 = TransitionBlock(128, 256, 1, p) self.layer3 = ResBlock(256, 512, 1, p) self.max_pool = nn.MaxPool2d(4, 4) self.linear = nn.Linear(512, num_classes) def forward(self, x): out = self.conv(x) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.max_pool(out) out = out.view(out.size(0), -1) out = self.linear(out) return F.log_softmax(out, dim=1)