handsomeboyMMk
commited on
Commit
·
269009c
1
Parent(s):
d984345
Upload 2 files
Browse files- generator.py +55 -0
- modules.py +63 -0
generator.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from .modules import Conv2dBlock, Concat
|
4 |
+
|
5 |
+
class SkipEncoderDecoder(nn.Module):
|
6 |
+
def __init__(self, input_depth, num_channels_down = [128] * 5, num_channels_up = [128] * 5, num_channels_skip = [128] * 5):
|
7 |
+
super(SkipEncoderDecoder, self).__init__()
|
8 |
+
|
9 |
+
self.model = nn.Sequential()
|
10 |
+
model_tmp = self.model
|
11 |
+
|
12 |
+
for i in range(len(num_channels_down)):
|
13 |
+
|
14 |
+
deeper = nn.Sequential()
|
15 |
+
skip = nn.Sequential()
|
16 |
+
|
17 |
+
if num_channels_skip[i] != 0:
|
18 |
+
model_tmp.add_module(str(len(model_tmp) + 1), Concat(1, skip, deeper))
|
19 |
+
else:
|
20 |
+
model_tmp.add_module(str(len(model_tmp) + 1), deeper)
|
21 |
+
|
22 |
+
model_tmp.add_module(str(len(model_tmp) + 1), nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < (len(num_channels_down) - 1) else num_channels_down[i])))
|
23 |
+
|
24 |
+
if num_channels_skip[i] != 0:
|
25 |
+
skip.add_module(str(len(skip) + 1), Conv2dBlock(input_depth, num_channels_skip[i], 1, bias = False))
|
26 |
+
|
27 |
+
deeper.add_module(str(len(deeper) + 1), Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias = False))
|
28 |
+
deeper.add_module(str(len(deeper) + 1), Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias = False))
|
29 |
+
|
30 |
+
deeper_main = nn.Sequential()
|
31 |
+
|
32 |
+
if i == len(num_channels_down) - 1:
|
33 |
+
k = num_channels_down[i]
|
34 |
+
else:
|
35 |
+
deeper.add_module(str(len(deeper) + 1), deeper_main)
|
36 |
+
k = num_channels_up[i + 1]
|
37 |
+
|
38 |
+
deeper.add_module(str(len(deeper) + 1), nn.Upsample(scale_factor = 2, mode = 'nearest'))
|
39 |
+
|
40 |
+
model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias = False))
|
41 |
+
model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias = False))
|
42 |
+
|
43 |
+
input_depth = num_channels_down[i]
|
44 |
+
model_tmp = deeper_main
|
45 |
+
|
46 |
+
self.model.add_module(str(len(self.model) + 1), nn.Conv2d(num_channels_up[0], 3, 1, bias = True))
|
47 |
+
self.model.add_module(str(len(self.model) + 1), nn.Sigmoid())
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
return self.model(x)
|
51 |
+
|
52 |
+
|
53 |
+
def input_noise(INPUT_DEPTH, spatial_size, scale = 1./10):
|
54 |
+
shape = [1, INPUT_DEPTH, spatial_size[0], spatial_size[1]]
|
55 |
+
return torch.rand(*shape) * scale
|
modules.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
class DepthwiseSeperableConv2d(nn.Module):
|
6 |
+
def __init__(self, input_channels, output_channels, **kwargs):
|
7 |
+
super(DepthwiseSeperableConv2d, self).__init__()
|
8 |
+
|
9 |
+
self.depthwise = nn.Conv2d(input_channels, input_channels, groups = input_channels, **kwargs)
|
10 |
+
self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size = 1)
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
x = self.depthwise(x)
|
14 |
+
x = self.pointwise(x)
|
15 |
+
|
16 |
+
return x
|
17 |
+
|
18 |
+
class Conv2dBlock(nn.Module):
|
19 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride = 1, bias = False):
|
20 |
+
super(Conv2dBlock, self).__init__()
|
21 |
+
|
22 |
+
self.model = nn.Sequential(
|
23 |
+
nn.ReflectionPad2d(int((kernel_size - 1) / 2)),
|
24 |
+
DepthwiseSeperableConv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 0, bias = bias),
|
25 |
+
nn.BatchNorm2d(out_channels),
|
26 |
+
nn.LeakyReLU(0.2)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
return self.model(x)
|
31 |
+
|
32 |
+
class Concat(nn.Module):
|
33 |
+
def __init__(self, dim, *args):
|
34 |
+
super(Concat, self).__init__()
|
35 |
+
self.dim = dim
|
36 |
+
|
37 |
+
for idx, module in enumerate(args):
|
38 |
+
self.add_module(str(idx), module)
|
39 |
+
|
40 |
+
def forward(self, input):
|
41 |
+
inputs = []
|
42 |
+
for module in self._modules.values():
|
43 |
+
inputs.append(module(input))
|
44 |
+
|
45 |
+
inputs_shapes2 = [x.shape[2] for x in inputs]
|
46 |
+
inputs_shapes3 = [x.shape[3] for x in inputs]
|
47 |
+
|
48 |
+
if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
|
49 |
+
inputs_ = inputs
|
50 |
+
else:
|
51 |
+
target_shape2 = min(inputs_shapes2)
|
52 |
+
target_shape3 = min(inputs_shapes3)
|
53 |
+
|
54 |
+
inputs_ = []
|
55 |
+
for inp in inputs:
|
56 |
+
diff2 = (inp.size(2) - target_shape2) // 2
|
57 |
+
diff3 = (inp.size(3) - target_shape3) // 2
|
58 |
+
inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])
|
59 |
+
|
60 |
+
return torch.cat(inputs_, dim=self.dim)
|
61 |
+
|
62 |
+
def __len__(self):
|
63 |
+
return len(self._modules)
|