akhaliq HF staff commited on
Commit
0870534
·
1 Parent(s): 1c296ac
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Intelligent Systems Lab Org
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
additional_utils/encoding_models.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###########################################################################
2
+ # Referred to: https://github.com/zhanghang1989/PyTorch-Encoding
3
+ ###########################################################################
4
+ import math
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.parallel.data_parallel import DataParallel
11
+ from torch.nn.parallel.scatter_gather import scatter
12
+ import threading
13
+ import torch
14
+ from torch.cuda._utils import _get_device_index
15
+ from torch.cuda.amp import autocast
16
+ from torch._utils import ExceptionWrapper
17
+
18
+ up_kwargs = {'mode': 'bilinear', 'align_corners': True}
19
+
20
+ __all__ = ['MultiEvalModule']
21
+
22
+ class MultiEvalModule(DataParallel):
23
+ """Multi-size Segmentation Eavluator"""
24
+ def __init__(self, module, nclass, device_ids=None, flip=True,
25
+ scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
26
+ super(MultiEvalModule, self).__init__(module, device_ids)
27
+ self.nclass = nclass
28
+ self.base_size = module.base_size
29
+ self.crop_size = module.crop_size
30
+ self.scales = scales
31
+ self.flip = flip
32
+ print('MultiEvalModule: base_size {}, crop_size {}'. \
33
+ format(self.base_size, self.crop_size))
34
+
35
+ def parallel_forward(self, inputs, **kwargs):
36
+ """Multi-GPU Mult-size Evaluation
37
+
38
+ Args:
39
+ inputs: list of Tensors
40
+ """
41
+ inputs = [(input.unsqueeze(0).cuda(device),)
42
+ for input, device in zip(inputs, self.device_ids)]
43
+ replicas = self.replicate(self, self.device_ids[:len(inputs)])
44
+ kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
45
+ if len(inputs) < len(kwargs):
46
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
47
+ elif len(kwargs) < len(inputs):
48
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
49
+ outputs = self.parallel_apply(replicas, inputs, kwargs)
50
+ #for out in outputs:
51
+ # print('out.size()', out.size())
52
+ return outputs
53
+
54
+ def forward(self, image):
55
+ """Mult-size Evaluation"""
56
+ # only single image is supported for evaluation
57
+ batch, _, h, w = image.size()
58
+ assert(batch == 1)
59
+ stride_rate = 2.0/3.0
60
+ crop_size = self.crop_size
61
+ stride = int(crop_size * stride_rate)
62
+ with torch.cuda.device_of(image):
63
+ scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
64
+
65
+ for scale in self.scales:
66
+ long_size = int(math.ceil(self.base_size * scale))
67
+ if h > w:
68
+ height = long_size
69
+ width = int(1.0 * w * long_size / h + 0.5)
70
+ short_size = width
71
+ else:
72
+ width = long_size
73
+ height = int(1.0 * h * long_size / w + 0.5)
74
+ short_size = height
75
+ """
76
+ short_size = int(math.ceil(self.base_size * scale))
77
+ if h > w:
78
+ width = short_size
79
+ height = int(1.0 * h * short_size / w)
80
+ long_size = height
81
+ else:
82
+ height = short_size
83
+ width = int(1.0 * w * short_size / h)
84
+ long_size = width
85
+ """
86
+ # resize image to current size
87
+ cur_img = resize_image(image, height, width, **self.module._up_kwargs)
88
+ if long_size <= crop_size:
89
+ pad_img = pad_image(cur_img, self.module.mean,
90
+ self.module.std, crop_size)
91
+ outputs = module_inference(self.module, pad_img, self.flip)
92
+ outputs = crop_image(outputs, 0, height, 0, width)
93
+ else:
94
+ if short_size < crop_size:
95
+ # pad if needed
96
+ pad_img = pad_image(cur_img, self.module.mean,
97
+ self.module.std, crop_size)
98
+ else:
99
+ pad_img = cur_img
100
+ _,_,ph,pw = pad_img.size()
101
+ assert(ph >= height and pw >= width)
102
+ # grid forward and normalize
103
+ h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1
104
+ w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1
105
+ with torch.cuda.device_of(image):
106
+ outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
107
+ count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
108
+ # grid evaluation
109
+ for idh in range(h_grids):
110
+ for idw in range(w_grids):
111
+ h0 = idh * stride
112
+ w0 = idw * stride
113
+ h1 = min(h0 + crop_size, ph)
114
+ w1 = min(w0 + crop_size, pw)
115
+ crop_img = crop_image(pad_img, h0, h1, w0, w1)
116
+ # pad if needed
117
+ pad_crop_img = pad_image(crop_img, self.module.mean,
118
+ self.module.std, crop_size)
119
+ output = module_inference(self.module, pad_crop_img, self.flip)
120
+ outputs[:,:,h0:h1,w0:w1] += crop_image(output,
121
+ 0, h1-h0, 0, w1-w0)
122
+ count_norm[:,:,h0:h1,w0:w1] += 1
123
+ assert((count_norm==0).sum()==0)
124
+ outputs = outputs / count_norm
125
+ outputs = outputs[:,:,:height,:width]
126
+
127
+ score = resize_image(outputs, h, w, **self.module._up_kwargs)
128
+ scores += score
129
+
130
+ return scores
131
+
132
+
133
+ def module_inference(module, image, flip=True):
134
+ output = module.evaluate(image)
135
+ if flip:
136
+ fimg = flip_image(image)
137
+ foutput = module.evaluate(fimg)
138
+ output += flip_image(foutput)
139
+ return output
140
+
141
+ def resize_image(img, h, w, **up_kwargs):
142
+ return F.interpolate(img, (h, w), **up_kwargs)
143
+
144
+ def pad_image(img, mean, std, crop_size):
145
+ b,c,h,w = img.size()
146
+ assert(c==3)
147
+ padh = crop_size - h if h < crop_size else 0
148
+ padw = crop_size - w if w < crop_size else 0
149
+ pad_values = -np.array(mean) / np.array(std)
150
+ img_pad = img.new().resize_(b,c,h+padh,w+padw)
151
+ for i in range(c):
152
+ # note that pytorch pad params is in reversed orders
153
+ img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
154
+ assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
155
+ return img_pad
156
+
157
+ def crop_image(img, h0, h1, w0, w1):
158
+ return img[:,:,h0:h1,w0:w1]
159
+
160
+ def flip_image(img):
161
+ assert(img.dim()==4)
162
+ with torch.cuda.device_of(img):
163
+ idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
164
+ return img.index_select(3, idx)
additional_utils/models.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###########################################################################
2
+ # Referred to: https://github.com/zhanghang1989/PyTorch-Encoding
3
+ ###########################################################################
4
+ import math
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.parallel.data_parallel import DataParallel
11
+ from torch.nn.parallel.scatter_gather import scatter
12
+ import threading
13
+ import torch
14
+ from torch.cuda._utils import _get_device_index
15
+ from torch.cuda.amp import autocast
16
+ from torch._utils import ExceptionWrapper
17
+
18
+ up_kwargs = {'mode': 'bilinear', 'align_corners': True}
19
+
20
+ __all__ = ['LSeg_MultiEvalModule']
21
+
22
+
23
+ class LSeg_MultiEvalModule(DataParallel):
24
+ """Multi-size Segmentation Eavluator"""
25
+ def __init__(self, module, device_ids=None, flip=True,
26
+ scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
27
+ super(LSeg_MultiEvalModule, self).__init__(module, device_ids)
28
+ self.base_size = module.base_size
29
+ self.crop_size = module.crop_size
30
+ self.scales = scales
31
+ self.flip = flip
32
+ print('MultiEvalModule: base_size {}, crop_size {}'. \
33
+ format(self.base_size, self.crop_size))
34
+
35
+ def parallel_forward(self, inputs, label_set='', **kwargs):
36
+ """Multi-GPU Mult-size Evaluation
37
+
38
+ Args:
39
+ inputs: list of Tensors
40
+ """
41
+ if len(label_set) < 10:
42
+ print('** MultiEvalModule parallel_forward phase: {} **'.format(label_set))
43
+ self.nclass = len(label_set)
44
+ inputs = [(input.unsqueeze(0).cuda(device),)
45
+ for input, device in zip(inputs, self.device_ids)]
46
+ replicas = self.replicate(self, self.device_ids[:len(inputs)])
47
+ kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
48
+ if len(inputs) < len(kwargs):
49
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
50
+ elif len(kwargs) < len(inputs):
51
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
52
+ outputs = parallel_apply(replicas, inputs, label_set, kwargs)
53
+ return outputs
54
+
55
+ def forward(self, image, label_set=''):
56
+ """Mult-size Evaluation"""
57
+ # only single image is supported for evaluation
58
+ if len(label_set) < 10:
59
+ print('** MultiEvalModule forward phase: {} **'.format(label_set))
60
+ batch, _, h, w = image.size()
61
+ assert(batch == 1)
62
+ self.nclass = len(label_set)
63
+ stride_rate = 2.0/3.0
64
+ crop_size = self.crop_size
65
+ stride = int(crop_size * stride_rate)
66
+ with torch.cuda.device_of(image):
67
+ scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
68
+
69
+ for scale in self.scales:
70
+ long_size = int(math.ceil(self.base_size * scale))
71
+ if h > w:
72
+ height = long_size
73
+ width = int(1.0 * w * long_size / h + 0.5)
74
+ short_size = width
75
+ else:
76
+ width = long_size
77
+ height = int(1.0 * h * long_size / w + 0.5)
78
+ short_size = height
79
+ """
80
+ short_size = int(math.ceil(self.base_size * scale))
81
+ if h > w:
82
+ width = short_size
83
+ height = int(1.0 * h * short_size / w)
84
+ long_size = height
85
+ else:
86
+ height = short_size
87
+ width = int(1.0 * w * short_size / h)
88
+ long_size = width
89
+ """
90
+ # resize image to current size
91
+ cur_img = resize_image(image, height, width, **self.module._up_kwargs)
92
+ if long_size <= crop_size:
93
+ pad_img = pad_image(cur_img, self.module.mean,
94
+ self.module.std, crop_size)
95
+ outputs = module_inference(self.module, pad_img, label_set, self.flip)
96
+ outputs = crop_image(outputs, 0, height, 0, width)
97
+ else:
98
+ if short_size < crop_size:
99
+ # pad if needed
100
+ pad_img = pad_image(cur_img, self.module.mean,
101
+ self.module.std, crop_size)
102
+ else:
103
+ pad_img = cur_img
104
+ _,_,ph,pw = pad_img.shape #.size()
105
+ assert(ph >= height and pw >= width)
106
+ # grid forward and normalize
107
+ h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1
108
+ w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1
109
+ with torch.cuda.device_of(image):
110
+ outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
111
+ count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
112
+ # grid evaluation
113
+ for idh in range(h_grids):
114
+ for idw in range(w_grids):
115
+ h0 = idh * stride
116
+ w0 = idw * stride
117
+ h1 = min(h0 + crop_size, ph)
118
+ w1 = min(w0 + crop_size, pw)
119
+ crop_img = crop_image(pad_img, h0, h1, w0, w1)
120
+ # pad if needed
121
+ pad_crop_img = pad_image(crop_img, self.module.mean,
122
+ self.module.std, crop_size)
123
+ output = module_inference(self.module, pad_crop_img, label_set, self.flip)
124
+ outputs[:,:,h0:h1,w0:w1] += crop_image(output,
125
+ 0, h1-h0, 0, w1-w0)
126
+ count_norm[:,:,h0:h1,w0:w1] += 1
127
+ assert((count_norm==0).sum()==0)
128
+ outputs = outputs / count_norm
129
+ outputs = outputs[:,:,:height,:width]
130
+ score = resize_image(outputs, h, w, **self.module._up_kwargs)
131
+ scores += score
132
+ return scores
133
+
134
+ def module_inference(module, image, label_set, flip=True):
135
+ output = module.evaluate_random(image, label_set)
136
+ if flip:
137
+ fimg = flip_image(image)
138
+ foutput = module.evaluate_random(fimg, label_set)
139
+ output += flip_image(foutput)
140
+ return output
141
+
142
+ def resize_image(img, h, w, **up_kwargs):
143
+ return F.interpolate(img, (h, w), **up_kwargs)
144
+
145
+ def pad_image(img, mean, std, crop_size):
146
+ b,c,h,w = img.shape #.size()
147
+ assert(c==3)
148
+ padh = crop_size - h if h < crop_size else 0
149
+ padw = crop_size - w if w < crop_size else 0
150
+ pad_values = -np.array(mean) / np.array(std)
151
+ img_pad = img.new().resize_(b,c,h+padh,w+padw)
152
+ for i in range(c):
153
+ # note that pytorch pad params is in reversed orders
154
+ img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
155
+ assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
156
+ return img_pad
157
+
158
+ def crop_image(img, h0, h1, w0, w1):
159
+ return img[:,:,h0:h1,w0:w1]
160
+
161
+ def flip_image(img):
162
+ assert(img.dim()==4)
163
+ with torch.cuda.device_of(img):
164
+ idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
165
+ return img.index_select(3, idx)
166
+
167
+
168
+ def get_a_var(obj):
169
+ if isinstance(obj, torch.Tensor):
170
+ return obj
171
+
172
+ if isinstance(obj, list) or isinstance(obj, tuple):
173
+ for result in map(get_a_var, obj):
174
+ if isinstance(result, torch.Tensor):
175
+ return result
176
+ if isinstance(obj, dict):
177
+ for result in map(get_a_var, obj.items()):
178
+ if isinstance(result, torch.Tensor):
179
+ return result
180
+ return None
181
+
182
+
183
+ def parallel_apply(modules, inputs, label_set, kwargs_tup=None, devices=None):
184
+ r"""Applies each `module` in :attr:`modules` in parallel on arguments
185
+ contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
186
+ on each of :attr:`devices`.
187
+
188
+ Args:
189
+ modules (Module): modules to be parallelized
190
+ inputs (tensor): inputs to the modules
191
+ devices (list of int or torch.device): CUDA devices
192
+
193
+ :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
194
+ :attr:`devices` (if given) should all have same length. Moreover, each
195
+ element of :attr:`inputs` can either be a single object as the only argument
196
+ to a module, or a collection of positional arguments.
197
+ """
198
+ assert len(modules) == len(inputs)
199
+ if kwargs_tup is not None:
200
+ assert len(modules) == len(kwargs_tup)
201
+ else:
202
+ kwargs_tup = ({},) * len(modules)
203
+ if devices is not None:
204
+ assert len(modules) == len(devices)
205
+ else:
206
+ devices = [None] * len(modules)
207
+ devices = [_get_device_index(x, True) for x in devices]
208
+ lock = threading.Lock()
209
+ results = {}
210
+ grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
211
+
212
+ def _worker(i, module, input, label_set, kwargs, device=None):
213
+ torch.set_grad_enabled(grad_enabled)
214
+ if device is None:
215
+ device = get_a_var(input).get_device()
216
+ try:
217
+ with torch.cuda.device(device), autocast(enabled=autocast_enabled):
218
+ # this also avoids accidental slicing of `input` if it is a Tensor
219
+ if not isinstance(input, (list, tuple)):
220
+ input = (input,)
221
+ output = module(*input, label_set, **kwargs)
222
+ with lock:
223
+ results[i] = output
224
+ except Exception:
225
+ with lock:
226
+ results[i] = ExceptionWrapper(
227
+ where="in replica {} on device {}".format(i, device))
228
+
229
+ if len(modules) > 1:
230
+ threads = [threading.Thread(target=_worker,
231
+ args=(i, module, input, label_set, kwargs, device))
232
+ for i, (module, input, kwargs, device) in
233
+ enumerate(zip(modules, inputs, kwargs_tup, devices))]
234
+
235
+ for thread in threads:
236
+ thread.start()
237
+ for thread in threads:
238
+ thread.join()
239
+ else:
240
+ _worker(0, modules[0], inputs[0], label_set, kwargs_tup[0], devices[0])
241
+
242
+ outputs = []
243
+ for i in range(len(inputs)):
244
+ output = results[i]
245
+ if isinstance(output, ExceptionWrapper):
246
+ output.reraise()
247
+ outputs.append(output)
248
+ return outputs
249
+
250
+
data/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import itertools
4
+ import functools
5
+ import numpy as np
6
+ import torch
7
+ import torch.utils.data
8
+ import torchvision.transforms as torch_transforms
9
+ import encoding.datasets as enc_ds
10
+
11
+ encoding_datasets = {
12
+ x: functools.partial(enc_ds.get_dataset, x)
13
+ for x in ["coco", "ade20k", "pascal_voc", "pascal_aug", "pcontext", "citys"]
14
+ }
15
+
16
+
17
+ def get_dataset(name, **kwargs):
18
+ if name in encoding_datasets:
19
+ return encoding_datasets[name.lower()](**kwargs)
20
+ assert False, f"dataset {name} not found"
21
+
22
+
23
+ def get_available_datasets():
24
+ return list(encoding_datasets.keys())
label_files/ade20k_objectInfo150.txt ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Idx,Ratio,Train,Val,Stuff,Name
2
+ 1,0.1576,11664,1172,1,wall
3
+ 2,0.1072,6046,612,1,building;edifice
4
+ 3,0.0878,8265,796,1,sky
5
+ 4,0.0621,9336,917,1,floor;flooring
6
+ 5,0.0480,6678,641,0,tree
7
+ 6,0.0450,6604,643,1,ceiling
8
+ 7,0.0398,4023,408,1,road;route
9
+ 8,0.0231,1906,199,0,bed
10
+ 9,0.0198,4688,460,0,windowpane;window
11
+ 10,0.0183,2423,225,1,grass
12
+ 11,0.0181,2874,294,0,cabinet
13
+ 12,0.0166,3068,310,1,sidewalk;pavement
14
+ 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
15
+ 14,0.0151,1804,190,1,earth;ground
16
+ 15,0.0118,6666,796,0,door;double;door
17
+ 16,0.0110,4269,411,0,table
18
+ 17,0.0109,1691,160,1,mountain;mount
19
+ 18,0.0104,3999,441,0,plant;flora;plant;life
20
+ 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
21
+ 20,0.0103,3261,318,0,chair
22
+ 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
23
+ 22,0.0074,709,75,1,water
24
+ 23,0.0067,3296,315,0,painting;picture
25
+ 24,0.0065,1191,106,0,sofa;couch;lounge
26
+ 25,0.0061,1516,162,0,shelf
27
+ 26,0.0060,667,69,1,house
28
+ 27,0.0053,651,57,1,sea
29
+ 28,0.0052,1847,224,0,mirror
30
+ 29,0.0046,1158,128,1,rug;carpet;carpeting
31
+ 30,0.0044,480,44,1,field
32
+ 31,0.0044,1172,98,0,armchair
33
+ 32,0.0044,1292,184,0,seat
34
+ 33,0.0033,1386,138,0,fence;fencing
35
+ 34,0.0031,698,61,0,desk
36
+ 35,0.0030,781,73,0,rock;stone
37
+ 36,0.0027,380,43,0,wardrobe;closet;press
38
+ 37,0.0026,3089,302,0,lamp
39
+ 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
40
+ 39,0.0024,804,99,0,railing;rail
41
+ 40,0.0023,1453,153,0,cushion
42
+ 41,0.0023,411,37,0,base;pedestal;stand
43
+ 42,0.0022,1440,162,0,box
44
+ 43,0.0022,800,77,0,column;pillar
45
+ 44,0.0020,2650,298,0,signboard;sign
46
+ 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
47
+ 46,0.0019,367,36,0,counter
48
+ 47,0.0018,311,30,1,sand
49
+ 48,0.0018,1181,122,0,sink
50
+ 49,0.0018,287,23,1,skyscraper
51
+ 50,0.0018,468,38,0,fireplace;hearth;open;fireplace
52
+ 51,0.0018,402,43,0,refrigerator;icebox
53
+ 52,0.0018,130,12,1,grandstand;covered;stand
54
+ 53,0.0018,561,64,1,path
55
+ 54,0.0017,880,102,0,stairs;steps
56
+ 55,0.0017,86,12,1,runway
57
+ 56,0.0017,172,11,0,case;display;case;showcase;vitrine
58
+ 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
59
+ 58,0.0017,930,109,0,pillow
60
+ 59,0.0015,139,18,0,screen;door;screen
61
+ 60,0.0015,564,52,1,stairway;staircase
62
+ 61,0.0015,320,26,1,river
63
+ 62,0.0015,261,29,1,bridge;span
64
+ 63,0.0014,275,22,0,bookcase
65
+ 64,0.0014,335,60,0,blind;screen
66
+ 65,0.0014,792,75,0,coffee;table;cocktail;table
67
+ 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
68
+ 67,0.0014,1309,138,0,flower
69
+ 68,0.0013,1112,113,0,book
70
+ 69,0.0013,266,27,1,hill
71
+ 70,0.0013,659,66,0,bench
72
+ 71,0.0012,331,31,0,countertop
73
+ 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
74
+ 73,0.0012,369,36,0,palm;palm;tree
75
+ 74,0.0012,144,9,0,kitchen;island
76
+ 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
77
+ 76,0.0010,324,33,0,swivel;chair
78
+ 77,0.0009,304,27,0,boat
79
+ 78,0.0009,170,20,0,bar
80
+ 79,0.0009,68,6,0,arcade;machine
81
+ 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
82
+ 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
83
+ 82,0.0008,492,49,0,towel
84
+ 83,0.0008,2510,269,0,light;light;source
85
+ 84,0.0008,440,39,0,truck;motortruck
86
+ 85,0.0008,147,18,1,tower
87
+ 86,0.0008,583,56,0,chandelier;pendant;pendent
88
+ 87,0.0007,533,61,0,awning;sunshade;sunblind
89
+ 88,0.0007,1989,239,0,streetlight;street;lamp
90
+ 89,0.0007,71,5,0,booth;cubicle;stall;kiosk
91
+ 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
92
+ 91,0.0007,135,12,0,airplane;aeroplane;plane
93
+ 92,0.0007,83,5,1,dirt;track
94
+ 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
95
+ 94,0.0006,1003,104,0,pole
96
+ 95,0.0006,182,12,1,land;ground;soil
97
+ 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
98
+ 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
99
+ 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
100
+ 99,0.0006,965,114,0,bottle
101
+ 100,0.0006,117,13,0,buffet;counter;sideboard
102
+ 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
103
+ 102,0.0006,108,9,1,stage
104
+ 103,0.0006,557,55,0,van
105
+ 104,0.0006,52,4,0,ship
106
+ 105,0.0005,99,5,0,fountain
107
+ 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
108
+ 107,0.0005,292,31,0,canopy
109
+ 108,0.0005,77,9,0,washer;automatic;washer;washing;machine
110
+ 109,0.0005,340,38,0,plaything;toy
111
+ 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
112
+ 111,0.0005,465,49,0,stool
113
+ 112,0.0005,50,4,0,barrel;cask
114
+ 113,0.0005,622,75,0,basket;handbasket
115
+ 114,0.0005,80,9,1,waterfall;falls
116
+ 115,0.0005,59,3,0,tent;collapsible;shelter
117
+ 116,0.0005,531,72,0,bag
118
+ 117,0.0005,282,30,0,minibike;motorbike
119
+ 118,0.0005,73,7,0,cradle
120
+ 119,0.0005,435,44,0,oven
121
+ 120,0.0005,136,25,0,ball
122
+ 121,0.0005,116,24,0,food;solid;food
123
+ 122,0.0004,266,31,0,step;stair
124
+ 123,0.0004,58,12,0,tank;storage;tank
125
+ 124,0.0004,418,83,0,trade;name;brand;name;brand;marque
126
+ 125,0.0004,319,43,0,microwave;microwave;oven
127
+ 126,0.0004,1193,139,0,pot;flowerpot
128
+ 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
129
+ 128,0.0004,347,36,0,bicycle;bike;wheel;cycle
130
+ 129,0.0004,52,5,1,lake
131
+ 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
132
+ 131,0.0004,108,13,0,screen;silver;screen;projection;screen
133
+ 132,0.0004,201,30,0,blanket;cover
134
+ 133,0.0004,285,21,0,sculpture
135
+ 134,0.0004,268,27,0,hood;exhaust;hood
136
+ 135,0.0003,1020,108,0,sconce
137
+ 136,0.0003,1282,122,0,vase
138
+ 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
139
+ 138,0.0003,453,57,0,tray
140
+ 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
141
+ 140,0.0003,397,44,0,fan
142
+ 141,0.0003,92,8,1,pier;wharf;wharfage;dock
143
+ 142,0.0003,228,18,0,crt;screen
144
+ 143,0.0003,570,59,0,plate
145
+ 144,0.0003,217,22,0,monitor;monitoring;device
146
+ 145,0.0003,206,19,0,bulletin;board;notice;board
147
+ 146,0.0003,130,14,0,shower
148
+ 147,0.0003,178,28,0,radiator
149
+ 148,0.0002,504,57,0,glass;drinking;glass
150
+ 149,0.0002,775,96,0,clock
151
+ 150,0.0002,421,56,0,flag
lseg_app.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import altair as alt
3
+ import math
4
+ import pandas as pd
5
+ import streamlit as st
6
+ st.set_page_config(layout="wide")
7
+
8
+ from PIL import Image
9
+
10
+ import os
11
+ import torch
12
+
13
+ import os
14
+ import argparse
15
+ import numpy as np
16
+ from tqdm import tqdm
17
+ from collections import OrderedDict
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch.utils import data
22
+ import torchvision.transforms as transform
23
+ from torch.nn.parallel.scatter_gather import gather
24
+
25
+ from additional_utils.models import LSeg_MultiEvalModule
26
+ from modules.lseg_module import LSegModule
27
+
28
+ import cv2
29
+ import math
30
+ import types
31
+ import functools
32
+ import torchvision.transforms as torch_transforms
33
+ import copy
34
+ import itertools
35
+ from PIL import Image
36
+ import matplotlib.pyplot as plt
37
+ import clip
38
+ from encoding.models.sseg import BaseNet
39
+ import matplotlib as mpl
40
+ import matplotlib.colors as mplc
41
+ import matplotlib.figure as mplfigure
42
+ import matplotlib.patches as mpatches
43
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
44
+ from data import get_dataset
45
+ import torchvision.transforms as transforms
46
+
47
+
48
+ def get_new_pallete(num_cls):
49
+ n = num_cls
50
+ pallete = [0]*(n*3)
51
+ for j in range(0,n):
52
+ lab = j
53
+ pallete[j*3+0] = 0
54
+ pallete[j*3+1] = 0
55
+ pallete[j*3+2] = 0
56
+ i = 0
57
+ while (lab > 0):
58
+ pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
59
+ pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
60
+ pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
61
+ i = i + 1
62
+ lab >>= 3
63
+ return pallete
64
+
65
+ def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
66
+ """Get image color pallete for visualizing masks"""
67
+ # put colormap
68
+ out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
69
+ out_img.putpalette(new_palette)
70
+
71
+ if out_label_flag:
72
+ assert labels is not None
73
+ u_index = np.unique(npimg)
74
+ patches = []
75
+ for i, index in enumerate(u_index):
76
+ label = labels[index]
77
+ cur_color = [new_palette[index * 3] / 255.0, new_palette[index * 3 + 1] / 255.0, new_palette[index * 3 + 2] / 255.0]
78
+ red_patch = mpatches.Patch(color=cur_color, label=label)
79
+ patches.append(red_patch)
80
+ return out_img, patches
81
+
82
+ @st.cache(allow_output_mutation=True)
83
+ def load_model():
84
+ class Options:
85
+ def __init__(self):
86
+ parser = argparse.ArgumentParser(description="PyTorch Segmentation")
87
+ # model and dataset
88
+ parser.add_argument(
89
+ "--model", type=str, default="encnet", help="model name (default: encnet)"
90
+ )
91
+ parser.add_argument(
92
+ "--backbone",
93
+ type=str,
94
+ default="clip_vitl16_384",
95
+ help="backbone name (default: resnet50)",
96
+ )
97
+ parser.add_argument(
98
+ "--dataset",
99
+ type=str,
100
+ default="ade20k",
101
+ help="dataset name (default: pascal12)",
102
+ )
103
+ parser.add_argument(
104
+ "--workers", type=int, default=16, metavar="N", help="dataloader threads"
105
+ )
106
+ parser.add_argument(
107
+ "--base-size", type=int, default=520, help="base image size"
108
+ )
109
+ parser.add_argument(
110
+ "--crop-size", type=int, default=480, help="crop image size"
111
+ )
112
+ parser.add_argument(
113
+ "--train-split",
114
+ type=str,
115
+ default="train",
116
+ help="dataset train split (default: train)",
117
+ )
118
+ parser.add_argument(
119
+ "--aux", action="store_true", default=False, help="Auxilary Loss"
120
+ )
121
+ parser.add_argument(
122
+ "--se-loss",
123
+ action="store_true",
124
+ default=False,
125
+ help="Semantic Encoding Loss SE-loss",
126
+ )
127
+ parser.add_argument(
128
+ "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
129
+ )
130
+ parser.add_argument(
131
+ "--batch-size",
132
+ type=int,
133
+ default=16,
134
+ metavar="N",
135
+ help="input batch size for \
136
+ training (default: auto)",
137
+ )
138
+ parser.add_argument(
139
+ "--test-batch-size",
140
+ type=int,
141
+ default=16,
142
+ metavar="N",
143
+ help="input batch size for \
144
+ testing (default: same as batch size)",
145
+ )
146
+ # cuda, seed and logging
147
+ parser.add_argument(
148
+ "--no-cuda",
149
+ action="store_true",
150
+ default=False,
151
+ help="disables CUDA training",
152
+ )
153
+ parser.add_argument(
154
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
155
+ )
156
+ # checking point
157
+ parser.add_argument(
158
+ "--weights", type=str, default='', help="checkpoint to test"
159
+ )
160
+ # evaluation option
161
+ parser.add_argument(
162
+ "--eval", action="store_true", default=False, help="evaluating mIoU"
163
+ )
164
+ parser.add_argument(
165
+ "--export",
166
+ type=str,
167
+ default=None,
168
+ help="put the path to resuming file if needed",
169
+ )
170
+ parser.add_argument(
171
+ "--acc-bn",
172
+ action="store_true",
173
+ default=False,
174
+ help="Re-accumulate BN statistics",
175
+ )
176
+ parser.add_argument(
177
+ "--test-val",
178
+ action="store_true",
179
+ default=False,
180
+ help="generate masks on val set",
181
+ )
182
+ parser.add_argument(
183
+ "--no-val",
184
+ action="store_true",
185
+ default=False,
186
+ help="skip validation during training",
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--module",
191
+ default='lseg',
192
+ help="select model definition",
193
+ )
194
+
195
+ # test option
196
+ parser.add_argument(
197
+ "--data-path", type=str, default='../datasets/', help="path to test image folder"
198
+ )
199
+
200
+ parser.add_argument(
201
+ "--no-scaleinv",
202
+ dest="scale_inv",
203
+ default=True,
204
+ action="store_false",
205
+ help="turn off scaleinv layers",
206
+ )
207
+
208
+ parser.add_argument(
209
+ "--widehead", default=False, action="store_true", help="wider output head"
210
+ )
211
+
212
+ parser.add_argument(
213
+ "--widehead_hr",
214
+ default=False,
215
+ action="store_true",
216
+ help="wider output head",
217
+ )
218
+ parser.add_argument(
219
+ "--ignore_index",
220
+ type=int,
221
+ default=-1,
222
+ help="numeric value of ignore label in gt",
223
+ )
224
+
225
+ parser.add_argument(
226
+ "--label_src",
227
+ type=str,
228
+ default="default",
229
+ help="how to get the labels",
230
+ )
231
+
232
+ parser.add_argument(
233
+ "--arch_option",
234
+ type=int,
235
+ default=0,
236
+ help="which kind of architecture to be used",
237
+ )
238
+
239
+ parser.add_argument(
240
+ "--block_depth",
241
+ type=int,
242
+ default=0,
243
+ help="how many blocks should be used",
244
+ )
245
+
246
+ parser.add_argument(
247
+ "--activation",
248
+ choices=['lrelu', 'tanh'],
249
+ default="lrelu",
250
+ help="use which activation to activate the block",
251
+ )
252
+
253
+ self.parser = parser
254
+
255
+ def parse(self):
256
+ args = self.parser.parse_args(args=[])
257
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
258
+ print(args)
259
+ return args
260
+
261
+ args = Options().parse()
262
+
263
+ torch.manual_seed(args.seed)
264
+ args.test_batch_size = 1
265
+ alpha=0.5
266
+
267
+ args.scale_inv = False
268
+ args.widehead = True
269
+ args.dataset = 'ade20k'
270
+ args.backbone = 'clip_vitl16_384'
271
+ args.weights = 'checkpoints/demo_e200.ckpt'
272
+ args.ignore_index = 255
273
+
274
+ module = LSegModule.load_from_checkpoint(
275
+ checkpoint_path=args.weights,
276
+ data_path=args.data_path,
277
+ dataset=args.dataset,
278
+ backbone=args.backbone,
279
+ aux=args.aux,
280
+ num_features=256,
281
+ aux_weight=0,
282
+ se_loss=False,
283
+ se_weight=0,
284
+ base_lr=0,
285
+ batch_size=1,
286
+ max_epochs=0,
287
+ ignore_index=args.ignore_index,
288
+ dropout=0.0,
289
+ scale_inv=args.scale_inv,
290
+ augment=False,
291
+ no_batchnorm=False,
292
+ widehead=args.widehead,
293
+ widehead_hr=args.widehead_hr,
294
+ map_locatin="cpu",
295
+ arch_option=0,
296
+ block_depth=0,
297
+ activation='lrelu',
298
+ )
299
+
300
+ input_transform = module.val_transform
301
+
302
+ # dataloader
303
+ loader_kwargs = (
304
+ {"num_workers": args.workers, "pin_memory": True} if args.cuda else {}
305
+ )
306
+
307
+ # model
308
+ if isinstance(module.net, BaseNet):
309
+ model = module.net
310
+ else:
311
+ model = module
312
+
313
+ model = model.eval()
314
+ model = model.cpu()
315
+ scales = (
316
+ [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
317
+ if args.dataset == "citys"
318
+ else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
319
+ )
320
+
321
+ model.mean = [0.5, 0.5, 0.5]
322
+ model.std = [0.5, 0.5, 0.5]
323
+ evaluator = LSeg_MultiEvalModule(
324
+ model, scales=scales, flip=True
325
+ ).cuda()
326
+ evaluator.eval()
327
+
328
+ transform = transforms.Compose(
329
+ [
330
+ transforms.ToTensor(),
331
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
332
+ transforms.Resize([360,480]),
333
+ ]
334
+ )
335
+
336
+ return evaluator, transform
337
+
338
+ """
339
+ # LSeg Demo
340
+ """
341
+ lseg_model, lseg_transform = load_model()
342
+ uploaded_file = st.file_uploader("Choose an image...")
343
+ input_labels = st.text_input("Input labels", value="dog, grass, other")
344
+ st.write("The labels are", input_labels)
345
+
346
+ if uploaded_file is not None:
347
+ image = Image.open(uploaded_file)
348
+ pimage = lseg_transform(np.array(image)).unsqueeze(0)
349
+
350
+ labels = []
351
+ for label in input_labels.split(","):
352
+ labels.append(label.strip())
353
+
354
+ with torch.no_grad():
355
+ outputs = lseg_model.parallel_forward(pimage, labels)
356
+
357
+ predicts = [
358
+ torch.max(output, 1)[1].cpu().numpy()
359
+ for output in outputs
360
+ ]
361
+
362
+ image = pimage[0].permute(1,2,0)
363
+ image = image * 0.5 + 0.5
364
+ image = Image.fromarray(np.uint8(255*image)).convert("RGBA")
365
+
366
+ pred = predicts[0]
367
+ new_palette = get_new_pallete(len(labels))
368
+ mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
369
+ seg = mask.convert("RGBA")
370
+
371
+ fig = plt.figure()
372
+ plt.subplot(121)
373
+ plt.imshow(image)
374
+ plt.axis('off')
375
+
376
+ plt.subplot(122)
377
+ plt.imshow(seg)
378
+ plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
379
+ plt.axis('off')
380
+
381
+ plt.tight_layout()
382
+
383
+ #st.image([image,seg], width=700, caption=["Input image", "Segmentation"])
384
+ st.pyplot(fig)
385
+
386
+
modules/lseg_module.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from argparse import ArgumentParser
6
+ import pytorch_lightning as pl
7
+ from .lsegmentation_module import LSegmentationModule
8
+ from .models.lseg_net import LSegNet
9
+ from encoding.models.sseg.base import up_kwargs
10
+
11
+ import os
12
+ import clip
13
+ import numpy as np
14
+
15
+ from scipy import signal
16
+ import glob
17
+
18
+ from PIL import Image
19
+ import matplotlib.pyplot as plt
20
+ import pandas as pd
21
+
22
+
23
+ class LSegModule(LSegmentationModule):
24
+ def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
25
+ super(LSegModule, self).__init__(
26
+ data_path, dataset, batch_size, base_lr, max_epochs, **kwargs
27
+ )
28
+
29
+ if dataset == "citys":
30
+ self.base_size = 2048
31
+ self.crop_size = 768
32
+ else:
33
+ self.base_size = 520
34
+ self.crop_size = 480
35
+
36
+ use_pretrained = True
37
+ norm_mean= [0.5, 0.5, 0.5]
38
+ norm_std = [0.5, 0.5, 0.5]
39
+
40
+ print('** Use norm {}, {} as the mean and std **'.format(norm_mean, norm_std))
41
+
42
+ train_transform = [
43
+ transforms.ToTensor(),
44
+ transforms.Normalize(norm_mean, norm_std),
45
+ ]
46
+
47
+ val_transform = [
48
+ transforms.ToTensor(),
49
+ transforms.Normalize(norm_mean, norm_std),
50
+ ]
51
+
52
+ self.train_transform = transforms.Compose(train_transform)
53
+ self.val_transform = transforms.Compose(val_transform)
54
+
55
+ self.trainset = self.get_trainset(
56
+ dataset,
57
+ augment=kwargs["augment"],
58
+ base_size=self.base_size,
59
+ crop_size=self.crop_size,
60
+ )
61
+
62
+ self.valset = self.get_valset(
63
+ dataset,
64
+ augment=kwargs["augment"],
65
+ base_size=self.base_size,
66
+ crop_size=self.crop_size,
67
+ )
68
+
69
+ use_batchnorm = (
70
+ (not kwargs["no_batchnorm"]) if "no_batchnorm" in kwargs else True
71
+ )
72
+ # print(kwargs)
73
+
74
+ labels = self.get_labels('ade20k')
75
+
76
+ self.net = LSegNet(
77
+ labels=labels,
78
+ backbone=kwargs["backbone"],
79
+ features=kwargs["num_features"],
80
+ crop_size=self.crop_size,
81
+ arch_option=kwargs["arch_option"],
82
+ block_depth=kwargs["block_depth"],
83
+ activation=kwargs["activation"],
84
+ )
85
+
86
+ self.net.pretrained.model.patch_embed.img_size = (
87
+ self.crop_size,
88
+ self.crop_size,
89
+ )
90
+
91
+ self._up_kwargs = up_kwargs
92
+ self.mean = norm_mean
93
+ self.std = norm_std
94
+
95
+ self.criterion = self.get_criterion(**kwargs)
96
+
97
+ def get_labels(self, dataset):
98
+ labels = []
99
+ path = 'label_files/{}_objectInfo150.txt'.format(dataset)
100
+ assert os.path.exists(path), '*** Error : {} not exist !!!'.format(path)
101
+ f = open(path, 'r')
102
+ lines = f.readlines()
103
+ for line in lines:
104
+ label = line.strip().split(',')[-1].split(';')[0]
105
+ labels.append(label)
106
+ f.close()
107
+ if dataset in ['ade20k']:
108
+ labels = labels[1:]
109
+ return labels
110
+
111
+
112
+ @staticmethod
113
+ def add_model_specific_args(parent_parser):
114
+ parser = LSegmentationModule.add_model_specific_args(parent_parser)
115
+ parser = ArgumentParser(parents=[parser])
116
+
117
+ parser.add_argument(
118
+ "--backbone",
119
+ type=str,
120
+ default="clip_vitl16_384",
121
+ help="backbone network",
122
+ )
123
+
124
+ parser.add_argument(
125
+ "--num_features",
126
+ type=int,
127
+ default=256,
128
+ help="number of featurs that go from encoder to decoder",
129
+ )
130
+
131
+ parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")
132
+
133
+ parser.add_argument(
134
+ "--finetune_weights", type=str, help="load weights to finetune from"
135
+ )
136
+
137
+ parser.add_argument(
138
+ "--no-scaleinv",
139
+ default=True,
140
+ action="store_false",
141
+ help="turn off scaleinv layers",
142
+ )
143
+
144
+ parser.add_argument(
145
+ "--no-batchnorm",
146
+ default=False,
147
+ action="store_true",
148
+ help="turn off batchnorm",
149
+ )
150
+
151
+ parser.add_argument(
152
+ "--widehead", default=False, action="store_true", help="wider output head"
153
+ )
154
+
155
+ parser.add_argument(
156
+ "--widehead_hr",
157
+ default=False,
158
+ action="store_true",
159
+ help="wider output head",
160
+ )
161
+
162
+ parser.add_argument(
163
+ "--arch_option",
164
+ type=int,
165
+ default=0,
166
+ help="which kind of architecture to be used",
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--block_depth",
171
+ type=int,
172
+ default=0,
173
+ help="how many blocks should be used",
174
+ )
175
+
176
+ parser.add_argument(
177
+ "--activation",
178
+ choices=['lrelu', 'tanh'],
179
+ default="lrelu",
180
+ help="use which activation to activate the block",
181
+ )
182
+
183
+ return parser
modules/lsegmentation_module.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ import time
3
+ import random
4
+ import clip
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms as transforms
8
+
9
+ from argparse import ArgumentParser
10
+
11
+ import pytorch_lightning as pl
12
+
13
+ from data import get_dataset, get_available_datasets
14
+
15
+ from encoding.models import get_segmentation_model
16
+ from encoding.nn import SegmentationLosses
17
+
18
+ from encoding.utils import batch_pix_accuracy, batch_intersection_union
19
+
20
+ # add mixed precision
21
+ import torch.cuda.amp as amp
22
+ import numpy as np
23
+
24
+ from encoding.utils import SegmentationMetric
25
+
26
+ class LSegmentationModule(pl.LightningModule):
27
+ def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
28
+ super().__init__()
29
+
30
+ self.data_path = data_path
31
+ self.batch_size = batch_size
32
+ self.base_lr = base_lr / 16 * batch_size
33
+ self.lr = self.base_lr
34
+
35
+ self.epochs = max_epochs
36
+ self.other_kwargs = kwargs
37
+ self.enabled = False #True mixed precision will make things complicated and leading to NAN error
38
+ self.scaler = amp.GradScaler(enabled=self.enabled)
39
+
40
+ def forward(self, x):
41
+ return self.net(x)
42
+
43
+ def evaluate(self, x, target=None):
44
+ pred = self.net.forward(x)
45
+ if isinstance(pred, (tuple, list)):
46
+ pred = pred[0]
47
+ if target is None:
48
+ return pred
49
+ correct, labeled = batch_pix_accuracy(pred.data, target.data)
50
+ inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
51
+
52
+ return correct, labeled, inter, union
53
+
54
+ def evaluate_random(self, x, labelset, target=None):
55
+ pred = self.net.forward(x, labelset)
56
+ if isinstance(pred, (tuple, list)):
57
+ pred = pred[0]
58
+ if target is None:
59
+ return pred
60
+ correct, labeled = batch_pix_accuracy(pred.data, target.data)
61
+ inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
62
+
63
+ return correct, labeled, inter, union
64
+
65
+
66
+ def training_step(self, batch, batch_nb):
67
+ img, target = batch
68
+ with amp.autocast(enabled=self.enabled):
69
+ out = self(img)
70
+ multi_loss = isinstance(out, tuple)
71
+ if multi_loss:
72
+ loss = self.criterion(*out, target)
73
+ else:
74
+ loss = self.criterion(out, target)
75
+ loss = self.scaler.scale(loss)
76
+ final_output = out[0] if multi_loss else out
77
+ train_pred, train_gt = self._filter_invalid(final_output, target)
78
+ if train_gt.nelement() != 0:
79
+ self.train_accuracy(train_pred, train_gt)
80
+ self.log("train_loss", loss)
81
+ return loss
82
+
83
+ def training_epoch_end(self, outs):
84
+ self.log("train_acc_epoch", self.train_accuracy.compute())
85
+
86
+ def validation_step(self, batch, batch_nb):
87
+ img, target = batch
88
+ out = self(img)
89
+ multi_loss = isinstance(out, tuple)
90
+ if multi_loss:
91
+ val_loss = self.criterion(*out, target)
92
+ else:
93
+ val_loss = self.criterion(out, target)
94
+ final_output = out[0] if multi_loss else out
95
+ valid_pred, valid_gt = self._filter_invalid(final_output, target)
96
+ self.val_iou.update(target, final_output)
97
+ pixAcc, iou = self.val_iou.get()
98
+ self.log("val_loss_step", val_loss)
99
+ self.log("pix_acc_step", pixAcc)
100
+ self.log(
101
+ "val_acc_step",
102
+ self.val_accuracy(valid_pred, valid_gt),
103
+ )
104
+ self.log("val_iou", iou)
105
+
106
+ def validation_epoch_end(self, outs):
107
+ pixAcc, iou = self.val_iou.get()
108
+ self.log("val_acc_epoch", self.val_accuracy.compute())
109
+ self.log("val_iou_epoch", iou)
110
+ self.log("pix_acc_epoch", pixAcc)
111
+
112
+ self.val_iou.reset()
113
+
114
+ def _filter_invalid(self, pred, target):
115
+ valid = target != self.other_kwargs["ignore_index"]
116
+ _, mx = torch.max(pred, dim=1)
117
+ return mx[valid], target[valid]
118
+
119
+ def configure_optimizers(self):
120
+ params_list = [
121
+ {"params": self.net.pretrained.parameters(), "lr": self.base_lr},
122
+ ]
123
+ if hasattr(self.net, "scratch"):
124
+ print("Found output scratch")
125
+ params_list.append(
126
+ {"params": self.net.scratch.parameters(), "lr": self.base_lr * 10}
127
+ )
128
+ if hasattr(self.net, "auxlayer"):
129
+ print("Found auxlayer")
130
+ params_list.append(
131
+ {"params": self.net.auxlayer.parameters(), "lr": self.base_lr * 10}
132
+ )
133
+ if hasattr(self.net, "scale_inv_conv"):
134
+ print(self.net.scale_inv_conv)
135
+ print("Found scaleinv layers")
136
+ params_list.append(
137
+ {
138
+ "params": self.net.scale_inv_conv.parameters(),
139
+ "lr": self.base_lr * 10,
140
+ }
141
+ )
142
+ params_list.append(
143
+ {"params": self.net.scale2_conv.parameters(), "lr": self.base_lr * 10}
144
+ )
145
+ params_list.append(
146
+ {"params": self.net.scale3_conv.parameters(), "lr": self.base_lr * 10}
147
+ )
148
+ params_list.append(
149
+ {"params": self.net.scale4_conv.parameters(), "lr": self.base_lr * 10}
150
+ )
151
+
152
+ if self.other_kwargs["midasproto"]:
153
+ print("Using midas optimization protocol")
154
+
155
+ opt = torch.optim.Adam(
156
+ params_list,
157
+ lr=self.base_lr,
158
+ betas=(0.9, 0.999),
159
+ weight_decay=self.other_kwargs["weight_decay"],
160
+ )
161
+ sch = torch.optim.lr_scheduler.LambdaLR(
162
+ opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
163
+ )
164
+
165
+ else:
166
+ opt = torch.optim.SGD(
167
+ params_list,
168
+ lr=self.base_lr,
169
+ momentum=0.9,
170
+ weight_decay=self.other_kwargs["weight_decay"],
171
+ )
172
+ sch = torch.optim.lr_scheduler.LambdaLR(
173
+ opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
174
+ )
175
+ return [opt], [sch]
176
+
177
+ def train_dataloader(self):
178
+ return torch.utils.data.DataLoader(
179
+ self.trainset,
180
+ batch_size=self.batch_size,
181
+ shuffle=True,
182
+ num_workers=16,
183
+ worker_init_fn=lambda x: random.seed(time.time() + x),
184
+ )
185
+
186
+ def val_dataloader(self):
187
+ return torch.utils.data.DataLoader(
188
+ self.valset,
189
+ batch_size=self.batch_size,
190
+ shuffle=False,
191
+ num_workers=16,
192
+ )
193
+
194
+ def get_trainset(self, dset, augment=False, **kwargs):
195
+ print(kwargs)
196
+ if augment == True:
197
+ mode = "train_x"
198
+ else:
199
+ mode = "train"
200
+
201
+ print(mode)
202
+ dset = get_dataset(
203
+ dset,
204
+ root=self.data_path,
205
+ split="train",
206
+ mode=mode,
207
+ transform=self.train_transform,
208
+ **kwargs
209
+ )
210
+
211
+ self.num_classes = dset.num_class
212
+ self.train_accuracy = pl.metrics.Accuracy()
213
+
214
+ return dset
215
+
216
+ def get_valset(self, dset, augment=False, **kwargs):
217
+ self.val_accuracy = pl.metrics.Accuracy()
218
+ self.val_iou = SegmentationMetric(self.num_classes)
219
+
220
+ if augment == True:
221
+ mode = "val_x"
222
+ else:
223
+ mode = "val"
224
+
225
+ print(mode)
226
+ return get_dataset(
227
+ dset,
228
+ root=self.data_path,
229
+ split="val",
230
+ mode=mode,
231
+ transform=self.val_transform,
232
+ **kwargs
233
+ )
234
+
235
+
236
+ def get_criterion(self, **kwargs):
237
+ return SegmentationLosses(
238
+ se_loss=kwargs["se_loss"],
239
+ aux=kwargs["aux"],
240
+ nclass=self.num_classes,
241
+ se_weight=kwargs["se_weight"],
242
+ aux_weight=kwargs["aux_weight"],
243
+ ignore_index=kwargs["ignore_index"],
244
+ )
245
+
246
+ @staticmethod
247
+ def add_model_specific_args(parent_parser):
248
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
249
+ parser.add_argument(
250
+ "--data_path", type=str, help="path where dataset is stored"
251
+ )
252
+ parser.add_argument(
253
+ "--dataset",
254
+ choices=get_available_datasets(),
255
+ default="ade20k",
256
+ help="dataset to train on",
257
+ )
258
+ parser.add_argument(
259
+ "--batch_size", type=int, default=16, help="size of the batches"
260
+ )
261
+ parser.add_argument(
262
+ "--base_lr", type=float, default=0.004, help="learning rate"
263
+ )
264
+ parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
265
+ parser.add_argument(
266
+ "--weight_decay", type=float, default=1e-4, help="weight_decay"
267
+ )
268
+ parser.add_argument(
269
+ "--aux", action="store_true", default=False, help="Auxilary Loss"
270
+ )
271
+ parser.add_argument(
272
+ "--aux-weight",
273
+ type=float,
274
+ default=0.2,
275
+ help="Auxilary loss weight (default: 0.2)",
276
+ )
277
+ parser.add_argument(
278
+ "--se-loss",
279
+ action="store_true",
280
+ default=False,
281
+ help="Semantic Encoding Loss SE-loss",
282
+ )
283
+ parser.add_argument(
284
+ "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
285
+ )
286
+
287
+ parser.add_argument(
288
+ "--midasproto", action="store_true", default=False, help="midasprotocol"
289
+ )
290
+
291
+ parser.add_argument(
292
+ "--ignore_index",
293
+ type=int,
294
+ default=-1,
295
+ help="numeric value of ignore label in gt",
296
+ )
297
+ parser.add_argument(
298
+ "--augment",
299
+ action="store_true",
300
+ default=False,
301
+ help="Use extended augmentations",
302
+ )
303
+
304
+ return parser
modules/models/lseg_blocks.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .lseg_vit import (
5
+ _make_pretrained_clip_vitl16_384,
6
+ _make_pretrained_clip_vitb32_384,
7
+ _make_pretrained_clipRN50x16_vitl16_384,
8
+ forward_vit,
9
+ )
10
+
11
+
12
+ def _make_encoder(
13
+ backbone,
14
+ features,
15
+ use_pretrained=True,
16
+ groups=1,
17
+ expand=False,
18
+ exportable=True,
19
+ hooks=None,
20
+ use_vit_only=False,
21
+ use_readout="ignore",
22
+ enable_attention_hooks=False,
23
+ ):
24
+ if backbone == "clip_vitl16_384":
25
+ clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384(
26
+ use_pretrained,
27
+ hooks=hooks,
28
+ use_readout=use_readout,
29
+ enable_attention_hooks=enable_attention_hooks,
30
+ )
31
+ scratch = _make_scratch(
32
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
33
+ )
34
+ elif backbone == "clipRN50x16_vitl16_384":
35
+ clip_pretrained, pretrained = _make_pretrained_clipRN50x16_vitl16_384(
36
+ use_pretrained,
37
+ hooks=hooks,
38
+ use_readout=use_readout,
39
+ enable_attention_hooks=enable_attention_hooks,
40
+ )
41
+ scratch = _make_scratch(
42
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
43
+ )
44
+ elif backbone == "clip_vitb32_384":
45
+ clip_pretrained, pretrained = _make_pretrained_clip_vitb32_384(
46
+ use_pretrained,
47
+ hooks=hooks,
48
+ use_readout=use_readout,
49
+ )
50
+ scratch = _make_scratch(
51
+ [96, 192, 384, 768], features, groups=groups, expand=expand
52
+ )
53
+ else:
54
+ print(f"Backbone '{backbone}' not implemented")
55
+ assert False
56
+
57
+ return clip_pretrained, pretrained, scratch
58
+
59
+
60
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
61
+ scratch = nn.Module()
62
+
63
+ out_shape1 = out_shape
64
+ out_shape2 = out_shape
65
+ out_shape3 = out_shape
66
+ out_shape4 = out_shape
67
+ if expand == True:
68
+ out_shape1 = out_shape
69
+ out_shape2 = out_shape * 2
70
+ out_shape3 = out_shape * 4
71
+ out_shape4 = out_shape * 8
72
+
73
+ scratch.layer1_rn = nn.Conv2d(
74
+ in_shape[0],
75
+ out_shape1,
76
+ kernel_size=3,
77
+ stride=1,
78
+ padding=1,
79
+ bias=False,
80
+ groups=groups,
81
+ )
82
+ scratch.layer2_rn = nn.Conv2d(
83
+ in_shape[1],
84
+ out_shape2,
85
+ kernel_size=3,
86
+ stride=1,
87
+ padding=1,
88
+ bias=False,
89
+ groups=groups,
90
+ )
91
+ scratch.layer3_rn = nn.Conv2d(
92
+ in_shape[2],
93
+ out_shape3,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=False,
98
+ groups=groups,
99
+ )
100
+ scratch.layer4_rn = nn.Conv2d(
101
+ in_shape[3],
102
+ out_shape4,
103
+ kernel_size=3,
104
+ stride=1,
105
+ padding=1,
106
+ bias=False,
107
+ groups=groups,
108
+ )
109
+
110
+ return scratch
111
+
112
+
113
+ class Interpolate(nn.Module):
114
+ """Interpolation module."""
115
+
116
+ def __init__(self, scale_factor, mode, align_corners=False):
117
+ """Init.
118
+
119
+ Args:
120
+ scale_factor (float): scaling
121
+ mode (str): interpolation mode
122
+ """
123
+ super(Interpolate, self).__init__()
124
+
125
+ self.interp = nn.functional.interpolate
126
+ self.scale_factor = scale_factor
127
+ self.mode = mode
128
+ self.align_corners = align_corners
129
+
130
+ def forward(self, x):
131
+ """Forward pass.
132
+
133
+ Args:
134
+ x (tensor): input
135
+
136
+ Returns:
137
+ tensor: interpolated data
138
+ """
139
+
140
+ x = self.interp(
141
+ x,
142
+ scale_factor=self.scale_factor,
143
+ mode=self.mode,
144
+ align_corners=self.align_corners,
145
+ )
146
+
147
+ return x
148
+
149
+
150
+ class ResidualConvUnit(nn.Module):
151
+ """Residual convolution module."""
152
+
153
+ def __init__(self, features):
154
+ """Init.
155
+
156
+ Args:
157
+ features (int): number of features
158
+ """
159
+ super().__init__()
160
+
161
+ self.conv1 = nn.Conv2d(
162
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
163
+ )
164
+
165
+ self.conv2 = nn.Conv2d(
166
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
167
+ )
168
+
169
+ self.relu = nn.ReLU(inplace=True)
170
+
171
+ def forward(self, x):
172
+ """Forward pass.
173
+
174
+ Args:
175
+ x (tensor): input
176
+
177
+ Returns:
178
+ tensor: output
179
+ """
180
+ out = self.relu(x)
181
+ out = self.conv1(out)
182
+ out = self.relu(out)
183
+ out = self.conv2(out)
184
+
185
+ return out + x
186
+
187
+
188
+ class FeatureFusionBlock(nn.Module):
189
+ """Feature fusion block."""
190
+
191
+ def __init__(self, features):
192
+ """Init.
193
+
194
+ Args:
195
+ features (int): number of features
196
+ """
197
+ super(FeatureFusionBlock, self).__init__()
198
+
199
+ self.resConfUnit1 = ResidualConvUnit(features)
200
+ self.resConfUnit2 = ResidualConvUnit(features)
201
+
202
+ def forward(self, *xs):
203
+ """Forward pass.
204
+
205
+ Returns:
206
+ tensor: output
207
+ """
208
+ output = xs[0]
209
+
210
+ if len(xs) == 2:
211
+ output += self.resConfUnit1(xs[1])
212
+
213
+ output = self.resConfUnit2(output)
214
+
215
+ output = nn.functional.interpolate(
216
+ output, scale_factor=2, mode="bilinear", align_corners=True
217
+ )
218
+
219
+ return output
220
+
221
+
222
+ class ResidualConvUnit_custom(nn.Module):
223
+ """Residual convolution module."""
224
+
225
+ def __init__(self, features, activation, bn):
226
+ """Init.
227
+
228
+ Args:
229
+ features (int): number of features
230
+ """
231
+ super().__init__()
232
+
233
+ self.bn = bn
234
+
235
+ self.groups = 1
236
+
237
+ self.conv1 = nn.Conv2d(
238
+ features,
239
+ features,
240
+ kernel_size=3,
241
+ stride=1,
242
+ padding=1,
243
+ bias=not self.bn,
244
+ groups=self.groups,
245
+ )
246
+
247
+ self.conv2 = nn.Conv2d(
248
+ features,
249
+ features,
250
+ kernel_size=3,
251
+ stride=1,
252
+ padding=1,
253
+ bias=not self.bn,
254
+ groups=self.groups,
255
+ )
256
+
257
+ if self.bn == True:
258
+ self.bn1 = nn.BatchNorm2d(features)
259
+ self.bn2 = nn.BatchNorm2d(features)
260
+
261
+ self.activation = activation
262
+
263
+ self.skip_add = nn.quantized.FloatFunctional()
264
+
265
+ def forward(self, x):
266
+ """Forward pass.
267
+
268
+ Args:
269
+ x (tensor): input
270
+
271
+ Returns:
272
+ tensor: output
273
+ """
274
+
275
+ out = self.activation(x)
276
+ out = self.conv1(out)
277
+ if self.bn == True:
278
+ out = self.bn1(out)
279
+
280
+ out = self.activation(out)
281
+ out = self.conv2(out)
282
+ if self.bn == True:
283
+ out = self.bn2(out)
284
+
285
+ if self.groups > 1:
286
+ out = self.conv_merge(out)
287
+
288
+ return self.skip_add.add(out, x)
289
+
290
+ # return out + x
291
+
292
+
293
+ class FeatureFusionBlock_custom(nn.Module):
294
+ """Feature fusion block."""
295
+
296
+ def __init__(
297
+ self,
298
+ features,
299
+ activation,
300
+ deconv=False,
301
+ bn=False,
302
+ expand=False,
303
+ align_corners=True,
304
+ ):
305
+ """Init.
306
+
307
+ Args:
308
+ features (int): number of features
309
+ """
310
+ super(FeatureFusionBlock_custom, self).__init__()
311
+
312
+ self.deconv = deconv
313
+ self.align_corners = align_corners
314
+
315
+ self.groups = 1
316
+
317
+ self.expand = expand
318
+ out_features = features
319
+ if self.expand == True:
320
+ out_features = features // 2
321
+
322
+ self.out_conv = nn.Conv2d(
323
+ features,
324
+ out_features,
325
+ kernel_size=1,
326
+ stride=1,
327
+ padding=0,
328
+ bias=True,
329
+ groups=1,
330
+ )
331
+
332
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
333
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
334
+
335
+ self.skip_add = nn.quantized.FloatFunctional()
336
+
337
+ def forward(self, *xs):
338
+ """Forward pass.
339
+
340
+ Returns:
341
+ tensor: output
342
+ """
343
+ output = xs[0]
344
+
345
+ if len(xs) == 2:
346
+ res = self.resConfUnit1(xs[1])
347
+ output = self.skip_add.add(output, res)
348
+ # output += res
349
+
350
+ output = self.resConfUnit2(output)
351
+
352
+ output = nn.functional.interpolate(
353
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
354
+ )
355
+
356
+ output = self.out_conv(output)
357
+
358
+ return output
359
+
modules/models/lseg_net.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import types
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from .lseg_blocks import FeatureFusionBlock, Interpolate, _make_encoder, FeatureFusionBlock_custom, forward_vit
9
+ import clip
10
+ import numpy as np
11
+ import pandas as pd
12
+ import os
13
+
14
+ class depthwise_clipseg_conv(nn.Module):
15
+ def __init__(self):
16
+ super(depthwise_clipseg_conv, self).__init__()
17
+ self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1)
18
+
19
+ def depthwise_clipseg(self, x, channels):
20
+ x = torch.cat([self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], dim=1)
21
+ return x
22
+
23
+ def forward(self, x):
24
+ channels = x.shape[1]
25
+ out = self.depthwise_clipseg(x, channels)
26
+ return out
27
+
28
+
29
+ class depthwise_conv(nn.Module):
30
+ def __init__(self, kernel_size=3, stride=1, padding=1):
31
+ super(depthwise_conv, self).__init__()
32
+ self.depthwise = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding)
33
+
34
+ def forward(self, x):
35
+ # support for 4D tensor with NCHW
36
+ C, H, W = x.shape[1:]
37
+ x = x.reshape(-1, 1, H, W)
38
+ x = self.depthwise(x)
39
+ x = x.view(-1, C, H, W)
40
+ return x
41
+
42
+
43
+ class depthwise_block(nn.Module):
44
+ def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
45
+ super(depthwise_block, self).__init__()
46
+ self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1)
47
+ if activation == 'relu':
48
+ self.activation = nn.ReLU()
49
+ elif activation == 'lrelu':
50
+ self.activation = nn.LeakyReLU()
51
+ elif activation == 'tanh':
52
+ self.activation = nn.Tanh()
53
+
54
+ def forward(self, x, act=True):
55
+ x = self.depthwise(x)
56
+ if act:
57
+ x = self.activation(x)
58
+ return x
59
+
60
+
61
+ class bottleneck_block(nn.Module):
62
+ def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
63
+ super(bottleneck_block, self).__init__()
64
+ self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1)
65
+ if activation == 'relu':
66
+ self.activation = nn.ReLU()
67
+ elif activation == 'lrelu':
68
+ self.activation = nn.LeakyReLU()
69
+ elif activation == 'tanh':
70
+ self.activation = nn.Tanh()
71
+
72
+
73
+ def forward(self, x, act=True):
74
+ sum_layer = x.max(dim=1, keepdim=True)[0]
75
+ x = self.depthwise(x)
76
+ x = x + sum_layer
77
+ if act:
78
+ x = self.activation(x)
79
+ return x
80
+
81
+ class BaseModel(torch.nn.Module):
82
+ def load(self, path):
83
+ """Load model from file.
84
+ Args:
85
+ path (str): file path
86
+ """
87
+ parameters = torch.load(path, map_location=torch.device("cpu"))
88
+
89
+ if "optimizer" in parameters:
90
+ parameters = parameters["model"]
91
+
92
+ self.load_state_dict(parameters)
93
+
94
+ def _make_fusion_block(features, use_bn):
95
+ return FeatureFusionBlock_custom(
96
+ features,
97
+ activation=nn.ReLU(False),
98
+ deconv=False,
99
+ bn=use_bn,
100
+ expand=False,
101
+ align_corners=True,
102
+ )
103
+
104
+ class LSeg(BaseModel):
105
+ def __init__(
106
+ self,
107
+ head,
108
+ features=256,
109
+ backbone="clip_vitl16_384",
110
+ readout="project",
111
+ channels_last=False,
112
+ use_bn=False,
113
+ **kwargs,
114
+ ):
115
+ super(LSeg, self).__init__()
116
+
117
+ self.channels_last = channels_last
118
+
119
+ hooks = {
120
+ "clip_vitl16_384": [5, 11, 17, 23],
121
+ "clipRN50x16_vitl16_384": [5, 11, 17, 23],
122
+ "clip_vitb32_384": [2, 5, 8, 11],
123
+ }
124
+
125
+ # Instantiate backbone and reassemble blocks
126
+ self.clip_pretrained, self.pretrained, self.scratch = _make_encoder(
127
+ backbone,
128
+ features,
129
+ groups=1,
130
+ expand=False,
131
+ exportable=False,
132
+ hooks=hooks[backbone],
133
+ use_readout=readout,
134
+ )
135
+
136
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
137
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
138
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
139
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
140
+
141
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).exp()
142
+ if backbone in ["clipRN50x16_vitl16_384"]:
143
+ self.out_c = 768
144
+ else:
145
+ self.out_c = 512
146
+ self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1)
147
+
148
+ self.arch_option = kwargs["arch_option"]
149
+ if self.arch_option == 1:
150
+ self.scratch.head_block = bottleneck_block(activation=kwargs["activation"])
151
+ self.block_depth = kwargs['block_depth']
152
+ elif self.arch_option == 2:
153
+ self.scratch.head_block = depthwise_block(activation=kwargs["activation"])
154
+ self.block_depth = kwargs['block_depth']
155
+
156
+ self.scratch.output_conv = head
157
+
158
+ self.text = clip.tokenize(self.labels)
159
+
160
+ def forward(self, x, labelset=''):
161
+ if labelset == '':
162
+ text = self.text
163
+ else:
164
+ text = clip.tokenize(labelset)
165
+
166
+ if self.channels_last == True:
167
+ x.contiguous(memory_format=torch.channels_last)
168
+
169
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
170
+
171
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
172
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
173
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
174
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
175
+
176
+ path_4 = self.scratch.refinenet4(layer_4_rn)
177
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
178
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
179
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
180
+
181
+ text = text.to(x.device)
182
+ self.logit_scale = self.logit_scale.to(x.device)
183
+ text_features = self.clip_pretrained.encode_text(text)
184
+
185
+ image_features = self.scratch.head1(path_1)
186
+
187
+ imshape = image_features.shape
188
+ image_features = image_features.permute(0,2,3,1).reshape(-1, self.out_c)
189
+
190
+ # normalized features
191
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
192
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
193
+
194
+ logits_per_image = self.logit_scale * image_features.half() @ text_features.t()
195
+
196
+ out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], -1).permute(0,3,1,2)
197
+
198
+ if self.arch_option in [1, 2]:
199
+ for _ in range(self.block_depth - 1):
200
+ out = self.scratch.head_block(out)
201
+ out = self.scratch.head_block(out, False)
202
+
203
+ out = self.scratch.output_conv(out)
204
+
205
+ return out
206
+
207
+
208
+ class LSegNet(LSeg):
209
+ """Network for semantic segmentation."""
210
+ def __init__(self, labels, path=None, scale_factor=0.5, crop_size=480, **kwargs):
211
+
212
+ features = kwargs["features"] if "features" in kwargs else 256
213
+ kwargs["use_bn"] = True
214
+
215
+ self.crop_size = crop_size
216
+ self.scale_factor = scale_factor
217
+ self.labels = labels
218
+
219
+ head = nn.Sequential(
220
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
221
+ )
222
+
223
+ super().__init__(head, **kwargs)
224
+
225
+ if path is not None:
226
+ self.load(path)
227
+
228
+
229
+
230
+
231
+
modules/models/lseg_vit.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+ import clip
8
+
9
+ activations = {}
10
+
11
+
12
+ def get_activation(name):
13
+ def hook(model, input, output):
14
+ activations[name] = output
15
+
16
+ return hook
17
+
18
+
19
+ attention = {}
20
+
21
+
22
+ def get_attention(name):
23
+ def hook(module, input, output):
24
+ x = input[0]
25
+ B, N, C = x.shape
26
+ qkv = (
27
+ module.qkv(x)
28
+ .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29
+ .permute(2, 0, 3, 1, 4)
30
+ )
31
+ q, k, v = (
32
+ qkv[0],
33
+ qkv[1],
34
+ qkv[2],
35
+ ) # make torchscript happy (cannot use tensor as tuple)
36
+
37
+ attn = (q @ k.transpose(-2, -1)) * module.scale
38
+
39
+ attn = attn.softmax(dim=-1) # [:,:,1,1:]
40
+ attention[name] = attn
41
+
42
+ return hook
43
+
44
+
45
+ def get_mean_attention_map(attn, token, shape):
46
+ attn = attn[:, :, token, 1:]
47
+ attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48
+ attn = torch.nn.functional.interpolate(
49
+ attn, size=shape[2:], mode="bicubic", align_corners=False
50
+ ).squeeze(0)
51
+
52
+ all_attn = torch.mean(attn, 0)
53
+
54
+ return all_attn
55
+
56
+
57
+ class Slice(nn.Module):
58
+ def __init__(self, start_index=1):
59
+ super(Slice, self).__init__()
60
+ self.start_index = start_index
61
+
62
+ def forward(self, x):
63
+ return x[:, self.start_index :]
64
+
65
+
66
+ class AddReadout(nn.Module):
67
+ def __init__(self, start_index=1):
68
+ super(AddReadout, self).__init__()
69
+ self.start_index = start_index
70
+
71
+ def forward(self, x):
72
+ if self.start_index == 2:
73
+ readout = (x[:, 0] + x[:, 1]) / 2
74
+ else:
75
+ readout = x[:, 0]
76
+ return x[:, self.start_index :] + readout.unsqueeze(1)
77
+
78
+
79
+ class ProjectReadout(nn.Module):
80
+ def __init__(self, in_features, start_index=1):
81
+ super(ProjectReadout, self).__init__()
82
+ self.start_index = start_index
83
+
84
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85
+
86
+ def forward(self, x):
87
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88
+ features = torch.cat((x[:, self.start_index :], readout), -1)
89
+
90
+ return self.project(features)
91
+
92
+
93
+ class Transpose(nn.Module):
94
+ def __init__(self, dim0, dim1):
95
+ super(Transpose, self).__init__()
96
+ self.dim0 = dim0
97
+ self.dim1 = dim1
98
+
99
+ def forward(self, x):
100
+ x = x.transpose(self.dim0, self.dim1)
101
+ return x
102
+
103
+
104
+ def forward_vit(pretrained, x):
105
+ b, c, h, w = x.shape
106
+
107
+ # encoder
108
+ glob = pretrained.model.forward_flex(x)
109
+
110
+ layer_1 = pretrained.activations["1"]
111
+ layer_2 = pretrained.activations["2"]
112
+ layer_3 = pretrained.activations["3"]
113
+ layer_4 = pretrained.activations["4"]
114
+
115
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
116
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
117
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
118
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
119
+
120
+ unflatten = nn.Sequential(
121
+ nn.Unflatten(
122
+ 2,
123
+ torch.Size(
124
+ [
125
+ h // pretrained.model.patch_size[1],
126
+ w // pretrained.model.patch_size[0],
127
+ ]
128
+ ),
129
+ )
130
+ )
131
+
132
+ if layer_1.ndim == 3:
133
+ layer_1 = unflatten(layer_1)
134
+ if layer_2.ndim == 3:
135
+ layer_2 = unflatten(layer_2)
136
+ if layer_3.ndim == 3:
137
+ layer_3 = unflatten(layer_3)
138
+ if layer_4.ndim == 3:
139
+ layer_4 = unflatten(layer_4)
140
+
141
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
142
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
143
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
144
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
145
+
146
+ return layer_1, layer_2, layer_3, layer_4
147
+
148
+
149
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
150
+ posemb_tok, posemb_grid = (
151
+ posemb[:, : self.start_index],
152
+ posemb[0, self.start_index :],
153
+ )
154
+
155
+ gs_old = int(math.sqrt(len(posemb_grid)))
156
+
157
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
158
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
159
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
160
+
161
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
162
+
163
+ return posemb
164
+
165
+
166
+ def forward_flex(self, x):
167
+ b, c, h, w = x.shape
168
+
169
+ pos_embed = self._resize_pos_embed(
170
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
171
+ )
172
+
173
+ B = x.shape[0]
174
+
175
+ if hasattr(self.patch_embed, "backbone"):
176
+ x = self.patch_embed.backbone(x)
177
+ if isinstance(x, (list, tuple)):
178
+ x = x[-1] # last feature if backbone outputs list/tuple of features
179
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
180
+
181
+ if getattr(self, "dist_token", None) is not None:
182
+ cls_tokens = self.cls_token.expand(
183
+ B, -1, -1
184
+ ) # stole cls_tokens impl from Phil Wang, thanks
185
+ dist_token = self.dist_token.expand(B, -1, -1)
186
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
187
+ else:
188
+ cls_tokens = self.cls_token.expand(
189
+ B, -1, -1
190
+ ) # stole cls_tokens impl from Phil Wang, thanks
191
+ x = torch.cat((cls_tokens, x), dim=1)
192
+
193
+ x = x + pos_embed
194
+ x = self.pos_drop(x)
195
+
196
+ for blk in self.blocks:
197
+ x = blk(x)
198
+
199
+ x = self.norm(x)
200
+
201
+ return x
202
+
203
+
204
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
205
+ if use_readout == "ignore":
206
+ readout_oper = [Slice(start_index)] * len(features)
207
+ elif use_readout == "add":
208
+ readout_oper = [AddReadout(start_index)] * len(features)
209
+ elif use_readout == "project":
210
+ readout_oper = [
211
+ ProjectReadout(vit_features, start_index) for out_feat in features
212
+ ]
213
+ else:
214
+ assert (
215
+ False
216
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
217
+
218
+ return readout_oper
219
+
220
+
221
+ def _make_pretrained_clip_vitl16_384(
222
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
223
+ ):
224
+ clip_pretrained, _ = clip.load("ViT-B/32", device='cuda', jit=False)
225
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
226
+
227
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
228
+
229
+ pretrained = _make_vit_b16_backbone(
230
+ model,
231
+ features=[256, 512, 1024, 1024],
232
+ hooks=hooks,
233
+ vit_features=1024,
234
+ use_readout=use_readout,
235
+ enable_attention_hooks=enable_attention_hooks,
236
+ )
237
+ return clip_pretrained, pretrained
238
+
239
+
240
+ def _make_pretrained_clipRN50x16_vitl16_384(
241
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
242
+ ):
243
+ clip_pretrained, _ = clip.load("RN50x16", device='cuda', jit=False)
244
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
245
+
246
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
247
+
248
+ pretrained = _make_vit_b16_backbone(
249
+ model,
250
+ features=[256, 512, 1024, 1024],
251
+ hooks=hooks,
252
+ vit_features=1024,
253
+ use_readout=use_readout,
254
+ enable_attention_hooks=enable_attention_hooks,
255
+ )
256
+ return clip_pretrained, pretrained
257
+
258
+
259
+ def _make_pretrained_clip_vitb32_384(pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False):
260
+ clip_pretrained, _ = clip.load("ViT-B/32", device='cuda', jit=False)
261
+ model = timm.create_model("vit_base_patch32_384", pretrained=pretrained)
262
+
263
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
264
+
265
+ pretrained = _make_vit_b32_backbone(
266
+ model,
267
+ features=[96, 192, 384, 768],
268
+ hooks=hooks,
269
+ use_readout=use_readout,
270
+ enable_attention_hooks=False,
271
+ )
272
+ return clip_pretrained, pretrained
273
+
274
+
275
+ def _make_vit_b32_backbone(
276
+ model,
277
+ features=[96, 192, 384, 768],
278
+ size=[384, 384],
279
+ hooks=[2, 5, 8, 11],
280
+ vit_features=768,
281
+ use_readout="ignore",
282
+ start_index=1,
283
+ enable_attention_hooks=False,
284
+ ):
285
+ pretrained = nn.Module()
286
+
287
+ pretrained.model = model
288
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
289
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
290
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
291
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
292
+
293
+ pretrained.activations = activations
294
+
295
+ pretrained.model.patch_size = [32, 32]
296
+ pretrained.model.start_index = start_index
297
+
298
+ if enable_attention_hooks:
299
+ pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
300
+ get_attention("attn_1")
301
+ )
302
+ pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
303
+ get_attention("attn_2")
304
+ )
305
+ pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
306
+ get_attention("attn_3")
307
+ )
308
+ pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
309
+ get_attention("attn_4")
310
+ )
311
+ pretrained.attention = attention
312
+
313
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
314
+
315
+ pretrained.act_postprocess1 = nn.Sequential(
316
+ readout_oper[0],
317
+ Transpose(1, 2),
318
+ nn.Unflatten(2, torch.Size([size[0] // pretrained.model.patch_size[1], size[1] // pretrained.model.patch_size[0]])),
319
+ nn.Conv2d(
320
+ in_channels=vit_features,
321
+ out_channels=features[0],
322
+ kernel_size=1,
323
+ stride=1,
324
+ padding=0,
325
+ ),
326
+ nn.ConvTranspose2d(
327
+ in_channels=features[0],
328
+ out_channels=features[0],
329
+ kernel_size=8,
330
+ stride=8,
331
+ padding=0,
332
+ bias=True,
333
+ dilation=1,
334
+ groups=1,
335
+ ),
336
+ )
337
+
338
+ pretrained.act_postprocess2 = nn.Sequential(
339
+ readout_oper[1],
340
+ Transpose(1, 2),
341
+ nn.Unflatten(2, torch.Size([size[0] // pretrained.model.patch_size[1], size[1] // pretrained.model.patch_size[0]])),
342
+ nn.Conv2d(
343
+ in_channels=vit_features,
344
+ out_channels=features[1],
345
+ kernel_size=1,
346
+ stride=1,
347
+ padding=0,
348
+ ),
349
+ nn.ConvTranspose2d(
350
+ in_channels=features[1],
351
+ out_channels=features[1],
352
+ kernel_size=4,
353
+ stride=4,
354
+ padding=0,
355
+ bias=True,
356
+ dilation=1,
357
+ groups=1,
358
+ ),
359
+ )
360
+
361
+ pretrained.act_postprocess3 = nn.Sequential(
362
+ readout_oper[2],
363
+ Transpose(1, 2),
364
+ nn.Unflatten(2, torch.Size([size[0] // pretrained.model.patch_size[1], size[1] // pretrained.model.patch_size[0]])),
365
+ nn.Conv2d(
366
+ in_channels=vit_features,
367
+ out_channels=features[2],
368
+ kernel_size=1,
369
+ stride=1,
370
+ padding=0,
371
+ ),
372
+ nn.ConvTranspose2d(
373
+ in_channels=features[2],
374
+ out_channels=features[2],
375
+ kernel_size=2,
376
+ stride=2,
377
+ padding=0,
378
+ # output_padding=output_padding,
379
+ bias=True,
380
+ dilation=1,
381
+ groups=1,
382
+ ),
383
+ )
384
+
385
+ pretrained.act_postprocess4 = nn.Sequential(
386
+ readout_oper[3],
387
+ Transpose(1, 2),
388
+ nn.Unflatten(2, torch.Size([size[0] // pretrained.model.patch_size[1], size[1] // pretrained.model.patch_size[0]])),
389
+ nn.Conv2d(
390
+ in_channels=vit_features,
391
+ out_channels=features[3],
392
+ kernel_size=1,
393
+ stride=1,
394
+ padding=0,
395
+ ),
396
+ )
397
+
398
+ # We inject this function into the VisionTransformer instances so that
399
+ # we can use it with interpolated position embeddings without modifying the library source.
400
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
401
+ pretrained.model._resize_pos_embed = types.MethodType(
402
+ _resize_pos_embed, pretrained.model
403
+ )
404
+
405
+ return pretrained
406
+
407
+
408
+ def _make_vit_b16_backbone(
409
+ model,
410
+ features=[96, 192, 384, 768],
411
+ size=[384, 384],
412
+ hooks=[2, 5, 8, 11],
413
+ vit_features=768,
414
+ use_readout="ignore",
415
+ start_index=1,
416
+ enable_attention_hooks=False,
417
+ ):
418
+ pretrained = nn.Module()
419
+
420
+ pretrained.model = model
421
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
422
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
423
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
424
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
425
+
426
+ pretrained.activations = activations
427
+
428
+ if enable_attention_hooks:
429
+ pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
430
+ get_attention("attn_1")
431
+ )
432
+ pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
433
+ get_attention("attn_2")
434
+ )
435
+ pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
436
+ get_attention("attn_3")
437
+ )
438
+ pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
439
+ get_attention("attn_4")
440
+ )
441
+ pretrained.attention = attention
442
+
443
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
444
+
445
+ # 32, 48, 136, 384
446
+ pretrained.act_postprocess1 = nn.Sequential(
447
+ readout_oper[0],
448
+ Transpose(1, 2),
449
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
450
+ nn.Conv2d(
451
+ in_channels=vit_features,
452
+ out_channels=features[0],
453
+ kernel_size=1,
454
+ stride=1,
455
+ padding=0,
456
+ ),
457
+ nn.ConvTranspose2d(
458
+ in_channels=features[0],
459
+ out_channels=features[0],
460
+ kernel_size=4,
461
+ stride=4,
462
+ padding=0,
463
+ bias=True,
464
+ dilation=1,
465
+ groups=1,
466
+ ),
467
+ )
468
+
469
+ pretrained.act_postprocess2 = nn.Sequential(
470
+ readout_oper[1],
471
+ Transpose(1, 2),
472
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
473
+ nn.Conv2d(
474
+ in_channels=vit_features,
475
+ out_channels=features[1],
476
+ kernel_size=1,
477
+ stride=1,
478
+ padding=0,
479
+ ),
480
+ nn.ConvTranspose2d(
481
+ in_channels=features[1],
482
+ out_channels=features[1],
483
+ kernel_size=2,
484
+ stride=2,
485
+ padding=0,
486
+ bias=True,
487
+ dilation=1,
488
+ groups=1,
489
+ ),
490
+ )
491
+
492
+ pretrained.act_postprocess3 = nn.Sequential(
493
+ readout_oper[2],
494
+ Transpose(1, 2),
495
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
496
+ nn.Conv2d(
497
+ in_channels=vit_features,
498
+ out_channels=features[2],
499
+ kernel_size=1,
500
+ stride=1,
501
+ padding=0,
502
+ ),
503
+ )
504
+
505
+ pretrained.act_postprocess4 = nn.Sequential(
506
+ readout_oper[3],
507
+ Transpose(1, 2),
508
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
509
+ nn.Conv2d(
510
+ in_channels=vit_features,
511
+ out_channels=features[3],
512
+ kernel_size=1,
513
+ stride=1,
514
+ padding=0,
515
+ ),
516
+ nn.Conv2d(
517
+ in_channels=features[3],
518
+ out_channels=features[3],
519
+ kernel_size=3,
520
+ stride=2,
521
+ padding=1,
522
+ ),
523
+ )
524
+
525
+ pretrained.model.start_index = start_index
526
+ pretrained.model.patch_size = [16, 16]
527
+
528
+ # We inject this function into the VisionTransformer instances so that
529
+ # we can use it with interpolated position embeddings without modifying the library source.
530
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
531
+ pretrained.model._resize_pos_embed = types.MethodType(
532
+ _resize_pos_embed, pretrained.model
533
+ )
534
+
535
+ return pretrained
prepare_ade20k.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # +
2
+ # revised from https://github.com/zhanghang1989/PyTorch-Encoding/blob/331ecdd5306104614cb414b16fbcd9d1a8d40e1e/scripts/prepare_ade20k.py
3
+
4
+ """Prepare ADE20K dataset"""
5
+ import os
6
+ import shutil
7
+ import argparse
8
+ import zipfile
9
+ from encoding.utils import download, mkdir
10
+ # -
11
+
12
+ _TARGET_DIR = os.path.expanduser('../datasets/')
13
+
14
+ def parse_args():
15
+ parser = argparse.ArgumentParser(
16
+ description='Initialize ADE20K dataset.',
17
+ epilog='Example: python prepare_ade20k.py',
18
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19
+ parser.add_argument('--download-dir', default=None, help='dataset directory on disk')
20
+ args = parser.parse_args()
21
+ return args
22
+
23
+ def download_ade(path, overwrite=False):
24
+ _AUG_DOWNLOAD_URLS = [
25
+ ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'),
26
+ ('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 'e05747892219d10e9243933371a497e905a4860c'),]
27
+ download_dir = path
28
+ mkdir(download_dir)
29
+ for url, checksum in _AUG_DOWNLOAD_URLS:
30
+ filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum)
31
+ # extract
32
+ with zipfile.ZipFile(filename,"r") as zip_ref:
33
+ zip_ref.extractall(path=path)
34
+
35
+
36
+ if __name__ == '__main__':
37
+ args = parse_args()
38
+ mkdir(os.path.expanduser('../datasets/'))
39
+ if args.download_dir is not None:
40
+ if os.path.isdir(_TARGET_DIR):
41
+ os.remove(_TARGET_DIR)
42
+ # make symlink
43
+ os.symlink(args.download_dir, _TARGET_DIR)
44
+ else:
45
+ download_ade(_TARGET_DIR, overwrite=False)
test_lseg.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from collections import OrderedDict
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch.utils import data
9
+ import torchvision.transforms as transform
10
+ from torch.nn.parallel.scatter_gather import gather
11
+ import encoding.utils as utils
12
+ from encoding.nn import SegmentationLosses, SyncBatchNorm
13
+ from encoding.parallel import DataParallelModel, DataParallelCriterion
14
+ from encoding.datasets import test_batchify_fn
15
+ from encoding.models.sseg import BaseNet
16
+ from modules.lseg_module import LSegModule
17
+ from utils import Resize
18
+ import cv2
19
+ import math
20
+ import types
21
+ import functools
22
+ import torchvision.transforms as torch_transforms
23
+ import copy
24
+ import itertools
25
+ from PIL import Image
26
+ import matplotlib.pyplot as plt
27
+ import clip
28
+ import matplotlib as mpl
29
+ import matplotlib.colors as mplc
30
+ import matplotlib.figure as mplfigure
31
+ import matplotlib.patches as mpatches
32
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
33
+ from data import get_dataset
34
+ from additional_utils.encoding_models import MultiEvalModule as LSeg_MultiEvalModule
35
+ import torchvision.transforms as transforms
36
+
37
+ class Options:
38
+ def __init__(self):
39
+ parser = argparse.ArgumentParser(description="PyTorch Segmentation")
40
+ # model and dataset
41
+ parser.add_argument(
42
+ "--model", type=str, default="encnet", help="model name (default: encnet)"
43
+ )
44
+ parser.add_argument(
45
+ "--backbone",
46
+ type=str,
47
+ default="clip_vitl16_384",
48
+ help="backbone name (default: resnet50)",
49
+ )
50
+ parser.add_argument(
51
+ "--dataset",
52
+ type=str,
53
+ default="ade20k",
54
+ help="dataset name (default: pascal12)",
55
+ )
56
+ parser.add_argument(
57
+ "--workers", type=int, default=16, metavar="N", help="dataloader threads"
58
+ )
59
+ parser.add_argument(
60
+ "--base-size", type=int, default=520, help="base image size"
61
+ )
62
+ parser.add_argument(
63
+ "--crop-size", type=int, default=480, help="crop image size"
64
+ )
65
+ parser.add_argument(
66
+ "--train-split",
67
+ type=str,
68
+ default="train",
69
+ help="dataset train split (default: train)",
70
+ )
71
+ # training hyper params
72
+ parser.add_argument(
73
+ "--aux", action="store_true", default=False, help="Auxilary Loss"
74
+ )
75
+ parser.add_argument(
76
+ "--se-loss",
77
+ action="store_true",
78
+ default=False,
79
+ help="Semantic Encoding Loss SE-loss",
80
+ )
81
+ parser.add_argument(
82
+ "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
83
+ )
84
+ parser.add_argument(
85
+ "--batch-size",
86
+ type=int,
87
+ default=16,
88
+ metavar="N",
89
+ help="input batch size for \
90
+ training (default: auto)",
91
+ )
92
+ parser.add_argument(
93
+ "--test-batch-size",
94
+ type=int,
95
+ default=16,
96
+ metavar="N",
97
+ help="input batch size for \
98
+ testing (default: same as batch size)",
99
+ )
100
+ # cuda, seed and logging
101
+ parser.add_argument(
102
+ "--no-cuda",
103
+ action="store_true",
104
+ default=False,
105
+ help="disables CUDA training",
106
+ )
107
+ parser.add_argument(
108
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
109
+ )
110
+ parser.add_argument(
111
+ "--weights", type=str, default=None, help="checkpoint to test"
112
+ )
113
+ parser.add_argument(
114
+ "--eval", action="store_true", default=False, help="evaluating mIoU"
115
+ )
116
+ parser.add_argument(
117
+ "--export",
118
+ type=str,
119
+ default=None,
120
+ help="put the path to resuming file if needed",
121
+ )
122
+
123
+ parser.add_argument(
124
+ "--acc-bn",
125
+ action="store_true",
126
+ default=False,
127
+ help="Re-accumulate BN statistics",
128
+ )
129
+ parser.add_argument(
130
+ "--test-val",
131
+ action="store_true",
132
+ default=False,
133
+ help="generate masks on val set",
134
+ )
135
+ parser.add_argument(
136
+ "--no-val",
137
+ action="store_true",
138
+ default=False,
139
+ help="skip validation during training",
140
+ )
141
+ parser.add_argument(
142
+ "--module",
143
+ default='lseg',
144
+ help="select model definition",
145
+ )
146
+ # test option
147
+ parser.add_argument(
148
+ "--data-path", type=str, default=None, help="path to test image folder"
149
+ )
150
+ parser.add_argument(
151
+ "--no-scaleinv",
152
+ dest="scale_inv",
153
+ default=True,
154
+ action="store_false",
155
+ help="turn off scaleinv layers",
156
+ )
157
+ parser.add_argument(
158
+ "--widehead", default=False, action="store_true", help="wider output head"
159
+ )
160
+ parser.add_argument(
161
+ "--widehead_hr",
162
+ default=False,
163
+ action="store_true",
164
+ help="wider output head",
165
+ )
166
+ parser.add_argument(
167
+ "--ignore_index",
168
+ type=int,
169
+ default=-1,
170
+ help="numeric value of ignore label in gt",
171
+ )
172
+ parser.add_argument(
173
+ "--label_src",
174
+ type=str,
175
+ default="default",
176
+ help="how to get the labels",
177
+ )
178
+ parser.add_argument(
179
+ "--jobname",
180
+ type=str,
181
+ default="default",
182
+ help="select which dataset",
183
+ )
184
+ parser.add_argument(
185
+ "--no-strict",
186
+ dest="strict",
187
+ default=True,
188
+ action="store_false",
189
+ help="no-strict copy the model",
190
+ )
191
+ parser.add_argument(
192
+ "--arch_option",
193
+ type=int,
194
+ default=0,
195
+ help="which kind of architecture to be used",
196
+ )
197
+ parser.add_argument(
198
+ "--block_depth",
199
+ type=int,
200
+ default=0,
201
+ help="how many blocks should be used",
202
+ )
203
+ parser.add_argument(
204
+ "--activation",
205
+ choices=['lrelu', 'tanh'],
206
+ default="lrelu",
207
+ help="use which activation to activate the block",
208
+ )
209
+
210
+ self.parser = parser
211
+
212
+ def parse(self):
213
+ args = self.parser.parse_args()
214
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
215
+ print(args)
216
+ return args
217
+
218
+
219
+ def test(args):
220
+
221
+ module = LSegModule.load_from_checkpoint(
222
+ checkpoint_path=args.weights,
223
+ data_path=args.data_path,
224
+ dataset=args.dataset,
225
+ backbone=args.backbone,
226
+ aux=args.aux,
227
+ num_features=256,
228
+ aux_weight=0,
229
+ se_loss=False,
230
+ se_weight=0,
231
+ base_lr=0,
232
+ batch_size=1,
233
+ max_epochs=0,
234
+ ignore_index=args.ignore_index,
235
+ dropout=0.0,
236
+ scale_inv=args.scale_inv,
237
+ augment=False,
238
+ no_batchnorm=False,
239
+ widehead=args.widehead,
240
+ widehead_hr=args.widehead_hr,
241
+ map_locatin="cpu",
242
+ arch_option=args.arch_option,
243
+ strict=args.strict,
244
+ block_depth=args.block_depth,
245
+ activation=args.activation,
246
+ )
247
+ input_transform = module.val_transform
248
+ num_classes = module.num_classes
249
+
250
+ # dataset
251
+ testset = get_dataset(
252
+ args.dataset,
253
+ root=args.data_path,
254
+ split="val",
255
+ mode="testval",
256
+ transform=input_transform,
257
+ )
258
+
259
+ # dataloader
260
+ loader_kwargs = (
261
+ {"num_workers": args.workers, "pin_memory": True} if args.cuda else {}
262
+ )
263
+ test_data = data.DataLoader(
264
+ testset,
265
+ batch_size=args.test_batch_size,
266
+ drop_last=False,
267
+ shuffle=False,
268
+ collate_fn=test_batchify_fn,
269
+ **loader_kwargs
270
+ )
271
+
272
+ if isinstance(module.net, BaseNet):
273
+ model = module.net
274
+ else:
275
+ model = module
276
+
277
+ model = model.eval()
278
+ model = model.cpu()
279
+
280
+ print(model)
281
+ if args.acc_bn:
282
+ from encoding.utils.precise_bn import update_bn_stats
283
+
284
+ data_kwargs = {
285
+ "transform": input_transform,
286
+ "base_size": args.base_size,
287
+ "crop_size": args.crop_size,
288
+ }
289
+ trainset = get_dataset(
290
+ args.dataset, split=args.train_split, mode="train", **data_kwargs
291
+ )
292
+ trainloader = data.DataLoader(
293
+ ReturnFirstClosure(trainset),
294
+ root=args.data_path,
295
+ batch_size=args.batch_size,
296
+ drop_last=True,
297
+ shuffle=True,
298
+ **loader_kwargs
299
+ )
300
+ print("Reseting BN statistics")
301
+ model.cuda()
302
+ update_bn_stats(model, trainloader)
303
+
304
+ if args.export:
305
+ torch.save(model.state_dict(), args.export + ".pth")
306
+ return
307
+
308
+ scales = (
309
+ [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
310
+ if args.dataset == "citys"
311
+ else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
312
+ )
313
+
314
+ evaluator = LSeg_MultiEvalModule(
315
+ model, num_classes, scales=scales, flip=True
316
+ ).cuda()
317
+ evaluator.eval()
318
+
319
+ metric = utils.SegmentationMetric(testset.num_class)
320
+ tbar = tqdm(test_data)
321
+
322
+ f = open("logs/log_test_{}_{}.txt".format(args.jobname, args.dataset), "a+")
323
+ per_class_iou = np.zeros(testset.num_class)
324
+ cnt = 0
325
+ for i, (image, dst) in enumerate(tbar):
326
+ if args.eval:
327
+ with torch.no_grad():
328
+ if False:
329
+ sample = {"image": image[0].cpu().permute(1, 2, 0).numpy()}
330
+ out = torch.zeros(
331
+ 1, testset.num_class, image[0].shape[1], image[0].shape[2]
332
+ ).cuda()
333
+
334
+ H, W = image[0].shape[1], image[0].shape[2]
335
+ for scale in scales:
336
+ long_size = int(math.ceil(520 * scale))
337
+ if H > W:
338
+ height = long_size
339
+ width = int(1.0 * W * long_size / H + 0.5)
340
+ short_size = width
341
+ else:
342
+ width = long_size
343
+ height = int(1.0 * H * long_size / W + 0.5)
344
+ short_size = height
345
+
346
+ rs = Resize(
347
+ width,
348
+ height,
349
+ resize_target=False,
350
+ keep_aspect_ratio=True,
351
+ ensure_multiple_of=32,
352
+ resize_method="minimal",
353
+ image_interpolation_method=cv2.INTER_AREA,
354
+ )
355
+
356
+ inf_image = (
357
+ torch.from_numpy(rs(sample)["image"])
358
+ .cuda()
359
+ .permute(2, 0, 1)
360
+ .unsqueeze(0)
361
+ )
362
+ inf_image = torch.cat((inf_image, torch.fliplr(inf_image)), 0)
363
+ try:
364
+ pred = model(inf_image)
365
+ except:
366
+ print(H, W, sz, i)
367
+ exit()
368
+
369
+ pred0 = F.softmax(pred[0], dim=1)
370
+ pred1 = F.softmax(pred[1], dim=1)
371
+
372
+ pred = pred0 + 0.2 * pred1
373
+
374
+ out += F.interpolate(
375
+ pred.sum(0, keepdim=True),
376
+ (out.shape[2], out.shape[3]),
377
+ mode="bilinear",
378
+ align_corners=True,
379
+ )
380
+
381
+ predicts = [out]
382
+ else:
383
+ predicts = evaluator.parallel_forward(image)
384
+
385
+ metric.update(dst, predicts)
386
+ pixAcc, mIoU = metric.get()
387
+
388
+ _, _, total_inter, total_union = metric.get_all()
389
+ per_class_iou += 1.0 * total_inter / (np.spacing(1) + total_union)
390
+ cnt+=1
391
+
392
+ tbar.set_description("pixAcc: %.4f, mIoU: %.4f" % (pixAcc, mIoU))
393
+ else:
394
+ with torch.no_grad():
395
+ outputs = evaluator.parallel_forward(image)
396
+ predicts = [
397
+ testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
398
+ for output in outputs
399
+ ]
400
+
401
+ # output folder
402
+ outdir = "outdir_ours"
403
+ if not os.path.exists(outdir):
404
+ os.makedirs(outdir)
405
+
406
+ for predict, impath in zip(predicts, dst):
407
+ mask = utils.get_mask_pallete(predict, args.dataset)
408
+ outname = os.path.splitext(impath)[0] + ".png"
409
+ mask.save(os.path.join(outdir, outname))
410
+
411
+ if args.eval:
412
+ each_classes_iou = per_class_iou/cnt
413
+ print("pixAcc: %.4f, mIoU: %.4f" % (pixAcc, mIoU))
414
+ print(each_classes_iou)
415
+ f.write("dataset {} ==> pixAcc: {:.4f}, mIoU: {:.4f}\n".format(args.dataset, pixAcc, mIoU))
416
+ for per_iou in each_classes_iou: f.write('{:.4f}, '.format(per_iou))
417
+ f.write('\n')
418
+
419
+
420
+ class ReturnFirstClosure(object):
421
+ def __init__(self, data):
422
+ self._data = data
423
+
424
+ def __len__(self):
425
+ return len(self._data)
426
+
427
+ def __getitem__(self, idx):
428
+ outputs = self._data[idx]
429
+ return outputs[0]
430
+
431
+
432
+ if __name__ == "__main__":
433
+ args = Options().parse()
434
+ torch.manual_seed(args.seed)
435
+ args.test_batch_size = torch.cuda.device_count()
436
+ test(args)
train_lseg.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from modules.lseg_module import LSegModule
2
+ from utils import do_training, get_default_argument_parser
3
+
4
+ if __name__ == "__main__":
5
+ parser = LSegModule.add_model_specific_args(get_default_argument_parser())
6
+ args = parser.parse_args()
7
+ do_training(args, LSegModule)
utils.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+
4
+ from glob import glob
5
+
6
+ from argparse import ArgumentParser
7
+ import torch
8
+ import pytorch_lightning as pl
9
+ import numpy as np
10
+ import cv2
11
+ import random
12
+ import math
13
+ from torchvision import transforms
14
+
15
+
16
+ def do_training(hparams, model_constructor):
17
+ # instantiate model
18
+ model = model_constructor(**vars(hparams))
19
+ # set all sorts of training parameters
20
+ hparams.gpus = -1
21
+ hparams.accelerator = "ddp"
22
+ hparams.benchmark = True
23
+
24
+ if hparams.dry_run:
25
+ print("Doing a dry run")
26
+ hparams.overfit_batches = hparams.batch_size
27
+
28
+ if not hparams.no_resume:
29
+ hparams = set_resume_parameters(hparams)
30
+
31
+ if not hasattr(hparams, "version") or hparams.version is None:
32
+ hparams.version = 0
33
+
34
+ hparams.sync_batchnorm = True
35
+
36
+ ttlogger = pl.loggers.TestTubeLogger(
37
+ "checkpoints", name=hparams.exp_name, version=hparams.version
38
+ )
39
+
40
+ hparams.callbacks = make_checkpoint_callbacks(hparams.exp_name, hparams.version)
41
+
42
+ wblogger = get_wandb_logger(hparams)
43
+ hparams.logger = [wblogger, ttlogger]
44
+
45
+ trainer = pl.Trainer.from_argparse_args(hparams)
46
+ trainer.fit(model)
47
+
48
+
49
+ def get_default_argument_parser():
50
+ parser = ArgumentParser(add_help=False)
51
+ parser.add_argument(
52
+ "--num_nodes",
53
+ type=int,
54
+ default=1,
55
+ help="number of nodes for distributed training",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--exp_name", type=str, required=True, help="name your experiment"
60
+ )
61
+
62
+ parser.add_argument(
63
+ "--dry-run",
64
+ action="store_true",
65
+ default=False,
66
+ help="run on batch of train/val/test",
67
+ )
68
+
69
+ parser.add_argument(
70
+ "--no_resume",
71
+ action="store_true",
72
+ default=False,
73
+ help="resume if we have a checkpoint",
74
+ )
75
+
76
+ parser.add_argument(
77
+ "--accumulate_grad_batches",
78
+ type=int,
79
+ default=1,
80
+ help="accumulate N batches for gradient computation",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--max_epochs", type=int, default=200, help="maximum number of epochs"
85
+ )
86
+
87
+ parser.add_argument(
88
+ "--project_name", type=str, default="lightseg", help="project name for logging"
89
+ )
90
+
91
+ return parser
92
+
93
+
94
+ def make_checkpoint_callbacks(exp_name, version, base_path="checkpoints", frequency=1):
95
+ version = 0 if version is None else version
96
+
97
+ base_callback = pl.callbacks.ModelCheckpoint(
98
+ dirpath=f"{base_path}/{exp_name}/version_{version}/checkpoints/",
99
+ save_last=True,
100
+ verbose=True,
101
+ )
102
+
103
+ val_callback = pl.callbacks.ModelCheckpoint(
104
+ monitor="val_acc_epoch",
105
+ dirpath=f"{base_path}/{exp_name}/version_{version}/checkpoints/",
106
+ filename="result-{epoch}-{val_acc_epoch:.2f}",
107
+ mode="max",
108
+ save_top_k=3,
109
+ verbose=True,
110
+ )
111
+
112
+ return [base_callback, val_callback]
113
+
114
+
115
+ def get_latest_version(folder):
116
+ versions = [
117
+ int(pathlib.PurePath(path).name.split("_")[-1])
118
+ for path in glob(f"{folder}/version_*/")
119
+ ]
120
+
121
+ if len(versions) == 0:
122
+ return None
123
+
124
+ versions.sort()
125
+ return versions[-1]
126
+
127
+
128
+ def get_latest_checkpoint(exp_name, version):
129
+ while version > -1:
130
+ folder = f"./checkpoints/{exp_name}/version_{version}/checkpoints/"
131
+
132
+ latest = f"{folder}/last.ckpt"
133
+ if os.path.exists(latest):
134
+ return latest, version
135
+
136
+ chkpts = glob(f"{folder}/epoch=*.ckpt")
137
+
138
+ if len(chkpts) > 0:
139
+ break
140
+
141
+ version -= 1
142
+
143
+ if len(chkpts) == 0:
144
+ return None, None
145
+
146
+ latest = max(chkpts, key=os.path.getctime)
147
+
148
+ return latest, version
149
+
150
+
151
+ def set_resume_parameters(hparams):
152
+ version = get_latest_version(f"./checkpoints/{hparams.exp_name}")
153
+
154
+ if version is not None:
155
+ latest, version = get_latest_checkpoint(hparams.exp_name, version)
156
+ print(f"Resuming checkpoint {latest}, exp_version={version}")
157
+
158
+ hparams.resume_from_checkpoint = latest
159
+ hparams.version = version
160
+
161
+ wandb_file = "checkpoints/{hparams.exp_name}/version_{version}/wandb_id"
162
+ if os.path.exists(wandb_file):
163
+ with open(wandb_file, "r") as f:
164
+ hparams.wandb_id = f.read()
165
+ else:
166
+ version = 0
167
+
168
+ return hparams
169
+
170
+
171
+ def get_wandb_logger(hparams):
172
+ exp_dir = f"checkpoints/{hparams.exp_name}/version_{hparams.version}/"
173
+ id_file = f"{exp_dir}/wandb_id"
174
+
175
+ if os.path.exists(id_file):
176
+ with open(id_file) as f:
177
+ hparams.wandb_id = f.read()
178
+ else:
179
+ hparams.wandb_id = None
180
+
181
+ logger = pl.loggers.WandbLogger(
182
+ save_dir="checkpoints",
183
+ project=hparams.project_name,
184
+ name=hparams.exp_name,
185
+ id=hparams.wandb_id,
186
+ )
187
+
188
+ if hparams.wandb_id is None:
189
+ _ = logger.experiment
190
+
191
+ if not os.path.exists(exp_dir):
192
+ os.makedirs(exp_dir)
193
+
194
+ with open(id_file, "w") as f:
195
+ f.write(logger.version)
196
+
197
+ return logger
198
+
199
+
200
+ class Resize(object):
201
+ """Resize sample to given size (width, height)."""
202
+
203
+ def __init__(
204
+ self,
205
+ width,
206
+ height,
207
+ resize_target=True,
208
+ keep_aspect_ratio=False,
209
+ ensure_multiple_of=1,
210
+ resize_method="lower_bound",
211
+ image_interpolation_method=cv2.INTER_AREA,
212
+ letter_box=False,
213
+ ):
214
+ """Init.
215
+
216
+ Args:
217
+ width (int): desired output width
218
+ height (int): desired output height
219
+ resize_target (bool, optional):
220
+ True: Resize the full sample (image, mask, target).
221
+ False: Resize image only.
222
+ Defaults to True.
223
+ keep_aspect_ratio (bool, optional):
224
+ True: Keep the aspect ratio of the input sample.
225
+ Output sample might not have the given width and height, and
226
+ resize behaviour depends on the parameter 'resize_method'.
227
+ Defaults to False.
228
+ ensure_multiple_of (int, optional):
229
+ Output width and height is constrained to be multiple of this parameter.
230
+ Defaults to 1.
231
+ resize_method (str, optional):
232
+ "lower_bound": Output will be at least as large as the given size.
233
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
234
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
235
+ Defaults to "lower_bound".
236
+ """
237
+ self.__width = width
238
+ self.__height = height
239
+
240
+ self.__resize_target = resize_target
241
+ self.__keep_aspect_ratio = keep_aspect_ratio
242
+ self.__multiple_of = ensure_multiple_of
243
+ self.__resize_method = resize_method
244
+ self.__image_interpolation_method = image_interpolation_method
245
+ self.__letter_box = letter_box
246
+
247
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
248
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
249
+
250
+ if max_val is not None and y > max_val:
251
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
252
+
253
+ if y < min_val:
254
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
255
+
256
+ return y
257
+
258
+ def get_size(self, width, height):
259
+ # determine new height and width
260
+ scale_height = self.__height / height
261
+ scale_width = self.__width / width
262
+
263
+ if self.__keep_aspect_ratio:
264
+ if self.__resize_method == "lower_bound":
265
+ # scale such that output size is lower bound
266
+ if scale_width > scale_height:
267
+ # fit width
268
+ scale_height = scale_width
269
+ else:
270
+ # fit height
271
+ scale_width = scale_height
272
+ elif self.__resize_method == "upper_bound":
273
+ # scale such that output size is upper bound
274
+ if scale_width < scale_height:
275
+ # fit width
276
+ scale_height = scale_width
277
+ else:
278
+ # fit height
279
+ scale_width = scale_height
280
+ elif self.__resize_method == "minimal":
281
+ # scale as least as possbile
282
+ if abs(1 - scale_width) < abs(1 - scale_height):
283
+ # fit width
284
+ scale_height = scale_width
285
+ else:
286
+ # fit height
287
+ scale_width = scale_height
288
+ else:
289
+ raise ValueError(
290
+ f"resize_method {self.__resize_method} not implemented"
291
+ )
292
+
293
+ if self.__resize_method == "lower_bound":
294
+ new_height = self.constrain_to_multiple_of(
295
+ scale_height * height, min_val=self.__height
296
+ )
297
+ new_width = self.constrain_to_multiple_of(
298
+ scale_width * width, min_val=self.__width
299
+ )
300
+ elif self.__resize_method == "upper_bound":
301
+ new_height = self.constrain_to_multiple_of(
302
+ scale_height * height, max_val=self.__height
303
+ )
304
+ new_width = self.constrain_to_multiple_of(
305
+ scale_width * width, max_val=self.__width
306
+ )
307
+ elif self.__resize_method == "minimal":
308
+ new_height = self.constrain_to_multiple_of(scale_height * height)
309
+ new_width = self.constrain_to_multiple_of(scale_width * width)
310
+ else:
311
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
312
+
313
+ return (new_width, new_height)
314
+
315
+ def make_letter_box(self, sample):
316
+ top = bottom = (self.__height - sample.shape[0]) // 2
317
+ left = right = (self.__width - sample.shape[1]) // 2
318
+ sample = cv2.copyMakeBorder(
319
+ sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0
320
+ )
321
+ return sample
322
+
323
+ def __call__(self, sample):
324
+ width, height = self.get_size(
325
+ sample["image"].shape[1], sample["image"].shape[0]
326
+ )
327
+
328
+ # resize sample
329
+ sample["image"] = cv2.resize(
330
+ sample["image"],
331
+ (width, height),
332
+ interpolation=self.__image_interpolation_method,
333
+ )
334
+
335
+ if self.__letter_box:
336
+ sample["image"] = self.make_letter_box(sample["image"])
337
+
338
+ if self.__resize_target:
339
+ if "disparity" in sample:
340
+ sample["disparity"] = cv2.resize(
341
+ sample["disparity"],
342
+ (width, height),
343
+ interpolation=cv2.INTER_NEAREST,
344
+ )
345
+
346
+ if self.__letter_box:
347
+ sample["disparity"] = self.make_letter_box(sample["disparity"])
348
+
349
+ if "depth" in sample:
350
+ sample["depth"] = cv2.resize(
351
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
352
+ )
353
+
354
+ if self.__letter_box:
355
+ sample["depth"] = self.make_letter_box(sample["depth"])
356
+
357
+ sample["mask"] = cv2.resize(
358
+ sample["mask"].astype(np.float32),
359
+ (width, height),
360
+ interpolation=cv2.INTER_NEAREST,
361
+ )
362
+
363
+ if self.__letter_box:
364
+ sample["mask"] = self.make_letter_box(sample["mask"])
365
+
366
+ sample["mask"] = sample["mask"].astype(bool)
367
+
368
+ return sample