# Bread: models/networks/networks.py
# Author: huqiming513 — "Upload 7 files" (commit 03b684c)
from models.networks.modules import *
class BaseNet(nn.Module):
    """U-Net-style encoder/decoder shared by IAN and ANSN.

    Encoder: ``inc`` (DoubleConv) then ``down1``..``down3`` (Down), with
    channel widths 32 -> 64 -> 128 -> 128.
    Decoder: ``up1``..``up3`` (Up), each also given the encoder feature map of
    matching depth as a skip input, then ``outc`` (OutConv).
    The Up input widths (256, 128, 64) equal decoder width + skip width, which
    suggests the skip is concatenated on the channel axis — confirm in
    ``models.networks.modules.Up``.

    Args:
        in_channels: channels of the input tensor; kept as ``self.n_channels``.
        out_channels: channels produced by ``outc``; kept as ``self.n_classes``.
        norm: forwarded as ``norm=`` to every DoubleConv/Down/Up block (its
            exact effect is defined in ``models.networks.modules``).
    """

    def __init__(self, in_channels=1, out_channels=1, norm=True):
        super().__init__()
        self.n_channels = in_channels
        self.n_classes = out_channels

        base = 32  # width of the shallowest stage; deeper widths are multiples
        # Submodules are created in this exact order so that state_dict key
        # order and parameter-initialisation order stay stable.
        self.inc = DoubleConv(in_channels, base, norm=norm)
        self.down1 = Down(base, 2 * base, norm=norm)
        self.down2 = Down(2 * base, 4 * base, norm=norm)
        self.down3 = Down(4 * base, 4 * base, norm=norm)
        self.up1 = Up(8 * base, 2 * base, bilinear=True, norm=norm)
        self.up2 = Up(4 * base, base, bilinear=True, norm=norm)
        self.up3 = Up(2 * base, base, bilinear=True, norm=norm)
        self.outc = OutConv(base, out_channels)

    def forward(self, x):
        """Encode ``x``, decode with skip inputs, and return the ``outc`` output.

        Assumes ``x`` is an image batch of shape (N, in_channels, H, W) —
        TODO confirm against the blocks in ``models.networks.modules``.
        """
        # features[i] is the output of encoder stage i (inc, down1, down2, down3).
        features = [self.inc(x)]
        for encode in (self.down1, self.down2, self.down3):
            features.append(encode(features[-1]))

        # The deepest feature map starts the decoder; the rest are skips,
        # consumed deepest-first.
        y = features.pop()
        for decode, skip in zip((self.up1, self.up2, self.up3), reversed(features)):
            y = decode(y, skip)
        return self.outc(y)
class IAN(BaseNet):
    """BaseNet under its own name; the architecture and defaults are unchanged."""

    def __init__(self, in_channels=1, out_channels=1, norm=True):
        super().__init__(in_channels=in_channels, out_channels=out_channels, norm=norm)
class ANSN(BaseNet):
    """BaseNet whose output head is rebuilt as ``OutConv(32, out_channels, act=False)``.

    ``act=False`` presumably removes OutConv's output activation — confirm in
    ``models.networks.modules.OutConv``.
    """

    def __init__(self, in_channels=1, out_channels=1, norm=True):
        super().__init__(in_channels=in_channels, out_channels=out_channels, norm=norm)
        # BaseNet.__init__ has already registered ``self.outc``; rebinding the
        # same attribute name replaces that submodule, so state_dict keys stay
        # ``outc.*``. It is built after the base init, as before, to keep the
        # parameter-initialisation order identical.
        plain_head = OutConv(32, out_channels, act=False)
        self.outc = plain_head
class FuseNet(nn.Module):
    """Two-level encoder/decoder built from the Attentive* blocks.

    Encoder: ``inc`` (AttentiveDoubleConv, 32 ch) -> ``down1`` (64 ch) ->
    ``down2`` (64 ch).
    Decoder: ``up1`` takes the ``down2`` output plus the ``down1`` skip.
    ``up2`` takes that result plus the ``inc`` skip. ``outc`` (OutConv) maps
    32 channels to ``out_channels``.
    Every attentive block is built with ``leaky=False``; this presumably selects
    a non-leaky activation — confirm in ``models.networks.modules``.

    Args:
        in_channels: channels of the input tensor; kept as ``self.n_channels``.
        out_channels: channels produced by ``outc``; kept as ``self.n_classes``.
        norm: forwarded as ``norm=`` to every attentive block (defaults to
            False here, unlike BaseNet).
    """

    def __init__(self, in_channels=1, out_channels=1, norm=False):
        super(FuseNet, self).__init__()
        # Same bookkeeping attributes as BaseNet, so callers can query any of
        # these networks uniformly. They are plain ints: not parameters, not in
        # state_dict.
        self.n_channels = in_channels
        self.n_classes = out_channels
        self.inc = AttentiveDoubleConv(in_channels, 32, norm=norm, leaky=False)
        self.down1 = AttentiveDown(32, 64, norm=norm, leaky=False)
        self.down2 = AttentiveDown(64, 64, norm=norm, leaky=False)
        # 128 = 64 (down2 output) + 64 (down1 skip); 64 = 32 (up1 output) + 32 (inc skip).
        self.up1 = AttentiveUp(128, 32, bilinear=True, norm=norm, leaky=False)
        self.up2 = AttentiveUp(64, 32, bilinear=True, norm=norm, leaky=False)
        self.outc = OutConv(32, out_channels)

    def forward(self, x):
        """Run the encoder, then the decoder with skip inputs; return the ``outc`` output."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        logits = self.outc(x)
        return logits
if __name__ == '__main__':
    # Quick sanity check: list every state_dict key of a FuseNet with
    # 4 input channels and 2 output channels, one per line.
    net = FuseNet(4, 2)
    for param_name in net.state_dict():
        print(param_name)