Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from torchvision.transforms import functional as f | |
class UNet(torch.nn.Module):
    """U-Net style encoder/decoder producing one score map per class.

    The encoder is four (double conv -> 2x2 max-pool) stages followed by a
    bottleneck double conv; the channel width doubles at every stage.  The
    decoder mirrors it with four (2x2 transposed conv -> concatenate with the
    matching encoder output -> double conv) stages, and a final 1x1
    convolution maps the result to ``num_classes`` channels.

    All 3x3 convolutions use ``padding=1`` and therefore keep the spatial
    size.  When a spatial size is odd at some pooling stage, the encoder
    feature map is larger than the up-sampled one and is centre-cropped
    before concatenation, so the output can be smaller than the input
    (e.g. a 37x37 input yields a 32x32 output).
    """

    def __init__(
        self,
        device=None,
        in_channels: int = 3,
        num_classes: int = 3,
        base_channels: int = 64,
    ) -> None:
        """Build all layers.

        Args:
            device: Device passed to every parametrised layer on creation
                (``None`` uses PyTorch's default device).
            in_channels: Number of channels of the input tensor.
            num_classes: Number of output channels of the final 1x1
                convolution (original comment: 3 for background, borders,
                objects).
            base_channels: Width of the first encoder stage; deeper stages
                use 2x, 4x, 8x and 16x this value.  The default of 64 gives
                the 64/128/256/512/1024 layout.

        Raises:
            ValueError: If ``base_channels`` is smaller than 1.
        """
        super().__init__()
        if base_channels < 1:
            raise ValueError(f"base_channels must be >= 1, got {base_channels}")

        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # NOTE: attribute names, the layer order inside each Sequential and
        # the order in which layers are created are kept stable on purpose so
        # that existing checkpoints (state_dict keys) and seeded
        # initialisation stay compatible.

        # ----- encoder -----
        self.block_1 = self._double_conv(in_channels, c1, device)
        self.max_pool_2x2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_2 = self._double_conv(c1, c2, device)
        self.max_pool_2x2_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_3 = self._double_conv(c2, c3, device)
        self.max_pool_2x2_3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_4 = self._double_conv(c3, c4, device)
        self.drop_out_1 = nn.Dropout(p=0.5)
        self.max_pool_2x2_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ----- bottleneck -----
        self.block_5 = self._double_conv(c4, c5, device)
        self.drop_out_2 = nn.Dropout(p=0.5)

        # ----- decoder -----
        # Each transposed conv halves the channel count; concatenating the
        # result with the encoder output of the same width doubles it again,
        # which is why every decoder block takes twice its output width.
        self.up_conv_2x2_1 = self._up_conv(c5, c4, device)
        self.block_6 = self._double_conv(c4 + c4, c4, device)
        self.drop_out_3 = nn.Dropout(p=0.5)
        self.up_conv_2x2_2 = self._up_conv(c4, c3, device)
        self.block_7 = self._double_conv(c3 + c3, c3, device)
        self.up_conv_2x2_3 = self._up_conv(c3, c2, device)
        self.block_8 = self._double_conv(c2 + c2, c2, device)
        self.up_conv_2x2_4 = self._up_conv(c2, c1, device)
        self.block_9 = self._double_conv(c1 + c1, c1, device)

        # ----- classifier head -----
        self.last_conv_1x1 = nn.Conv2d(
            in_channels=c1, out_channels=num_classes,
            kernel_size=1, stride=1, padding=0, device=device,
        )

    @staticmethod
    def _double_conv(in_channels: int, out_channels: int, device) -> nn.Sequential:
        """Return Conv3x3-BN-ReLU-Conv3x3-BN-ReLU mapping in_channels -> out_channels."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1, device=device),
            nn.BatchNorm2d(out_channels, device=device),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1, device=device),
            nn.BatchNorm2d(out_channels, device=device),
            nn.ReLU(),
        )

    @staticmethod
    def _up_conv(in_channels: int, out_channels: int, device) -> nn.ConvTranspose2d:
        """Return the 2x2, stride-2 transposed convolution used for up-sampling."""
        return nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=2, stride=2, padding=0, device=device,
        )

    @staticmethod
    def _center_crop(tensor, height: int, width: int):
        """Slice the central ``height`` x ``width`` window from the last two dims.

        The offsets use ``int(round(diff / 2.0))`` (round-half-to-even), the
        same rule as the torchvision ``center_crop`` this helper replaces, so
        the selected window is unchanged.  The caller guarantees that the
        tensor is at least as large as the requested window.
        """
        top = int(round((tensor.shape[-2] - height) / 2.0))
        left = int(round((tensor.shape[-1] - width) / 2.0))
        return tensor[..., top:top + height, left:left + width]

    def _merge(self, skip, upsampled):
        """Crop ``skip`` to the size of ``upsampled`` and stack both on the channel axis."""
        cropped = self._center_crop(skip, upsampled.shape[2], upsampled.shape[3])
        return torch.cat([cropped, upsampled], dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: 4-D tensor; dimension 1 must have ``in_channels`` entries and
                dimensions 2 and 3 are the spatial axes that get pooled and
                up-sampled.

        Returns:
            Tensor with ``num_classes`` entries in dimension 1.  Its spatial
            size equals the input's when both spatial sizes are multiples of
            16, otherwise it is smaller because of the centre crops.
        """
        # Contracting path.  Dropout is applied to the stage-4 output before
        # it is both pooled and reused as a skip connection.
        skip_1 = self.block_1(x)
        skip_2 = self.block_2(self.max_pool_2x2_1(skip_1))
        skip_3 = self.block_3(self.max_pool_2x2_2(skip_2))
        skip_4 = self.drop_out_1(self.block_4(self.max_pool_2x2_3(skip_3)))
        bottleneck = self.drop_out_2(self.block_5(self.max_pool_2x2_4(skip_4)))

        # Expanding path: up-sample, merge with the matching skip, refine.
        decoded = self.block_6(self._merge(skip_4, self.up_conv_2x2_1(bottleneck)))
        decoded = self.drop_out_3(decoded)
        decoded = self.block_7(self._merge(skip_3, self.up_conv_2x2_2(decoded)))
        decoded = self.block_8(self._merge(skip_2, self.up_conv_2x2_3(decoded)))
        decoded = self.block_9(self._merge(skip_1, self.up_conv_2x2_4(decoded)))

        return self.last_conv_1x1(decoded)