Cancer-Risk-Prediction-from-WSI / classes /binary_neural_classifier.py
VatsalPatel18's picture
Upload 32 files
7b1328a verified
raw
history blame
399 Bytes
import torch.nn as nn
import torch
class SimpleNN(nn.Module):
    """Three-layer fully connected network with a single sigmoid output.

    The input's last dimension (``input_dim``) is mapped through two
    ReLU-activated hidden layers to one output unit, which is passed
    through ``torch.sigmoid`` so every output lies in (0, 1).

    Calling ``SimpleNN()`` with no arguments builds the original
    512 -> 512 -> 256 -> 1 architecture.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim1: Width of the first hidden layer.
        hidden_dim2: Width of the second hidden layer.
    """

    def __init__(
        self,
        input_dim: int = 512,
        hidden_dim1: int = 512,
        hidden_dim2: int = 256,
    ) -> None:
        super().__init__()
        # The names fc1/fc2/fc3 are part of the state_dict keys; keep them
        # unchanged so previously saved checkpoints still load.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid scores of shape ``(*, 1)`` for input ``x`` of shape ``(*, input_dim)``.

        The sigmoid is applied here, so the result is already squashed into
        (0, 1). Do not feed it to a loss that expects raw logits (e.g.
        ``nn.BCEWithLogitsLoss``), which would apply the sigmoid a second time.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))