Spaces:
Runtime error
Runtime error
File size: 1,325 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone
@MODELS.register_module()
class LeNet5(BaseBackbone):
    """`LeNet5 <https://en.wikipedia.org/wiki/LeNet>`_ backbone.

    The input for LeNet-5 is a 32×32 grayscale image. With that input size
    the convolutional stack reduces the spatial extent to 1×1, producing a
    120-channel feature map.

    Args:
        num_classes (int): number of classes for classification.
            The default value is -1, which uses the backbone as
            a feature extractor without the top classifier.
    """

    def __init__(self, num_classes=-1):
        super(LeNet5, self).__init__()
        self.num_classes = num_classes

        # Spatial size for a 32x32 input: 32 -> 28 -> 14 -> 10 -> 5 -> 1.
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1), nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1), nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(16, 120, kernel_size=5, stride=1), nn.Tanh())

        # The classifier head only exists when a positive class count is
        # requested; otherwise the module is a pure feature extractor.
        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(120, 84),
                nn.Tanh(),
                nn.Linear(84, num_classes),
            )

    def forward(self, x):
        """Run the backbone on ``x``.

        Args:
            x (torch.Tensor): input image tensor; a batch of 32×32
                single-channel images yields ``(N, 120, 1, 1)`` features.

        Returns:
            tuple: a one-element tuple holding either the raw feature map
            (``num_classes <= 0``) or the classifier output
            (``num_classes > 0``).
        """
        x = self.features(x)
        if self.num_classes > 0:
            # Drop only the two trailing 1x1 spatial axes. A bare
            # ``squeeze()`` would also remove the batch axis when N == 1,
            # turning the logits into a 1-D tensor without a batch dim.
            x = self.classifier(x.squeeze(-1).squeeze(-1))

        return (x, )
|