Arnaudding001 committed
Commit 989f283 · 1 Parent(s): 9986e48

Create raft_core_extractor.py
Files changed (1)
  1. raft_core_extractor.py +267 -0
raft_core_extractor.py ADDED
"""Feature encoders from the RAFT optical flow architecture: residual and
bottleneck building blocks plus the BasicEncoder / SmallEncoder backbones."""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 conv residual block with a selectable normalization layer."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.nor3 = None  # placeholder removed below
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            # Project the skip connection with a strided 1x1 conv (plus norm3)
            # so it matches the main branch's spatial size and channel count.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
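
# Note (added commentary, not in the original file): when stride != 1, the
# identity path is replaced by the strided 1x1 conv + norm3 shortcut above.
# With the default norm_fn='group', GroupNorm uses planes // 8 groups, which
# divides evenly for every width used by BasicEncoder below (64, 96, 128).
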

class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with a selectable normalization layer."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes // 4)
            self.norm2 = nn.BatchNorm2d(planes // 4)
            self.norm3 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes // 4)
            self.norm2 = nn.InstanceNorm2d(planes // 4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if stride != 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            # Strided 1x1 projection shortcut, as in ResidualBlock.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
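
# Note (added commentary, not in the original file): the bottleneck variant
# squeezes to planes // 4 channels with 1x1 convolutions around the single 3x3
# conv, cutting parameters relative to ResidualBlock; SmallEncoder below is
# built from these blocks.
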

class BasicEncoder(nn.Module):
    """ResidualBlock backbone that encodes images to `output_dim` channels at 1/8 resolution."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; unit weight / zero bias for affine norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # If the input is a pair of images, stack them along the batch
        # dimension so both are encoded in a single pass.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Split back into two feature maps (assumes exactly two inputs).
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
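
# Usage sketch (an assumption based on how RAFT typically wires these
# encoders; names and dims are illustrative, not part of this file):
#
#     fnet = BasicEncoder(output_dim=256, norm_fn='instance')  # matching features
#     cnet = BasicEncoder(output_dim=256, norm_fn='batch')     # context features
#     fmap1, fmap2 = fnet([image1, image2])
#
# With three stride-2 stages (conv1, layer2, layer3), the output feature maps
# are at 1/8 of the input resolution.
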

class SmallEncoder(nn.Module):
    """Lighter BottleneckBlock backbone that encodes images to `output_dim` channels at 1/8 resolution."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # output convolution
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit weight / zero bias for affine norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # If the input is a pair of images, stack them along the batch
        # dimension so both are encoded in a single pass.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Split back into two feature maps (assumes exactly two inputs).
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
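
# ---------------------------------------------------------------------------
# Minimal smoke test (added; not part of the original commit). It runs a pair
# of random images through both encoders and checks the expected
# 1/8-resolution output shapes. Input sizes here are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    img1 = torch.randn(2, 3, 128, 128)  # batch of 2 RGB frames
    img2 = torch.randn(2, 3, 128, 128)

    basic = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)
    f1, f2 = basic([img1, img2])          # pair input -> pair of feature maps
    assert f1.shape == (2, 256, 16, 16)   # 128 / 8 = 16

    small = SmallEncoder(output_dim=128, norm_fn='none', dropout=0.0)
    g1 = small(img1)                      # a single tensor is also accepted
    assert g1.shape == (2, 128, 16, 16)

    print('BasicEncoder ->', tuple(f1.shape), '| SmallEncoder ->', tuple(g1.shape))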