import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Two-layer fully connected classifier.

    Architecture: ``Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, num_classes)``, followed by a softmax over the class
    dimension unless ``apply_softmax`` is False.

    Args:
        input_size: Number of input features (size of the input's last dim).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output classes.
        apply_softmax: If True (the default, matching the original model),
            ``forward`` returns class probabilities. Set to False to get raw
            logits, e.g. when training with ``nn.CrossEntropyLoss``, which
            applies log-softmax internally and expects unnormalized scores.
    """

    def __init__(self, input_size: int, hidden_size: int, num_classes: int,
                 *, apply_softmax: bool = True) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
        # nn.Linear maps the *last* dimension, so class scores live on dim -1.
        # dim=-1 equals dim=1 for (batch, features) input, and also handles
        # unbatched 1-D input and extra leading dimensions correctly.
        self.softmax = nn.Softmax(dim=-1)
        # Plain attribute (not a parameter/buffer), so state_dict is unchanged.
        self.apply_softmax = apply_softmax

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_size``.

        Returns:
            Tensor whose last dimension has size ``num_classes``. It holds
            probabilities summing to 1 along that dimension when
            ``apply_softmax`` is True, and raw logits otherwise.
        """
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        if self.apply_softmax:
            out = self.softmax(out)
        return out


if __name__ == '__main__':
    net = Net(100, 50, 10)
    # Only learnable parameters are saved (ReLU/Softmax have none), so the
    # keys are fc1.weight, fc1.bias, fc2.weight and fc2.bias.
    torch.save(net.state_dict(), 'model.pth')