handsomeboyMMk commited on
Commit
269009c
·
1 Parent(s): d984345

Upload 2 files

Browse files
Files changed (2) hide show
  1. generator.py +55 -0
  2. modules.py +63 -0
generator.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .modules import Conv2dBlock, Concat
4
+
5
+ class SkipEncoderDecoder(nn.Module):
6
+ def __init__(self, input_depth, num_channels_down = [128] * 5, num_channels_up = [128] * 5, num_channels_skip = [128] * 5):
7
+ super(SkipEncoderDecoder, self).__init__()
8
+
9
+ self.model = nn.Sequential()
10
+ model_tmp = self.model
11
+
12
+ for i in range(len(num_channels_down)):
13
+
14
+ deeper = nn.Sequential()
15
+ skip = nn.Sequential()
16
+
17
+ if num_channels_skip[i] != 0:
18
+ model_tmp.add_module(str(len(model_tmp) + 1), Concat(1, skip, deeper))
19
+ else:
20
+ model_tmp.add_module(str(len(model_tmp) + 1), deeper)
21
+
22
+ model_tmp.add_module(str(len(model_tmp) + 1), nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < (len(num_channels_down) - 1) else num_channels_down[i])))
23
+
24
+ if num_channels_skip[i] != 0:
25
+ skip.add_module(str(len(skip) + 1), Conv2dBlock(input_depth, num_channels_skip[i], 1, bias = False))
26
+
27
+ deeper.add_module(str(len(deeper) + 1), Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias = False))
28
+ deeper.add_module(str(len(deeper) + 1), Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias = False))
29
+
30
+ deeper_main = nn.Sequential()
31
+
32
+ if i == len(num_channels_down) - 1:
33
+ k = num_channels_down[i]
34
+ else:
35
+ deeper.add_module(str(len(deeper) + 1), deeper_main)
36
+ k = num_channels_up[i + 1]
37
+
38
+ deeper.add_module(str(len(deeper) + 1), nn.Upsample(scale_factor = 2, mode = 'nearest'))
39
+
40
+ model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias = False))
41
+ model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias = False))
42
+
43
+ input_depth = num_channels_down[i]
44
+ model_tmp = deeper_main
45
+
46
+ self.model.add_module(str(len(self.model) + 1), nn.Conv2d(num_channels_up[0], 3, 1, bias = True))
47
+ self.model.add_module(str(len(self.model) + 1), nn.Sigmoid())
48
+
49
+ def forward(self, x):
50
+ return self.model(x)
51
+
52
+
53
+ def input_noise(INPUT_DEPTH, spatial_size, scale = 1./10):
54
+ shape = [1, INPUT_DEPTH, spatial_size[0], spatial_size[1]]
55
+ return torch.rand(*shape) * scale
modules.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+
5
+ class DepthwiseSeperableConv2d(nn.Module):
6
+ def __init__(self, input_channels, output_channels, **kwargs):
7
+ super(DepthwiseSeperableConv2d, self).__init__()
8
+
9
+ self.depthwise = nn.Conv2d(input_channels, input_channels, groups = input_channels, **kwargs)
10
+ self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size = 1)
11
+
12
+ def forward(self, x):
13
+ x = self.depthwise(x)
14
+ x = self.pointwise(x)
15
+
16
+ return x
17
+
18
+ class Conv2dBlock(nn.Module):
19
+ def __init__(self, in_channels, out_channels, kernel_size, stride = 1, bias = False):
20
+ super(Conv2dBlock, self).__init__()
21
+
22
+ self.model = nn.Sequential(
23
+ nn.ReflectionPad2d(int((kernel_size - 1) / 2)),
24
+ DepthwiseSeperableConv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 0, bias = bias),
25
+ nn.BatchNorm2d(out_channels),
26
+ nn.LeakyReLU(0.2)
27
+ )
28
+
29
+ def forward(self, x):
30
+ return self.model(x)
31
+
32
+ class Concat(nn.Module):
33
+ def __init__(self, dim, *args):
34
+ super(Concat, self).__init__()
35
+ self.dim = dim
36
+
37
+ for idx, module in enumerate(args):
38
+ self.add_module(str(idx), module)
39
+
40
+ def forward(self, input):
41
+ inputs = []
42
+ for module in self._modules.values():
43
+ inputs.append(module(input))
44
+
45
+ inputs_shapes2 = [x.shape[2] for x in inputs]
46
+ inputs_shapes3 = [x.shape[3] for x in inputs]
47
+
48
+ if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
49
+ inputs_ = inputs
50
+ else:
51
+ target_shape2 = min(inputs_shapes2)
52
+ target_shape3 = min(inputs_shapes3)
53
+
54
+ inputs_ = []
55
+ for inp in inputs:
56
+ diff2 = (inp.size(2) - target_shape2) // 2
57
+ diff3 = (inp.size(3) - target_shape3) // 2
58
+ inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])
59
+
60
+ return torch.cat(inputs_, dim=self.dim)
61
+
62
+ def __len__(self):
63
+ return len(self._modules)