add files
- LICENSE +21 -0
- additional_utils/encoding_models.py +164 -0
- additional_utils/models.py +250 -0
- data/__init__.py +24 -0
- label_files/ade20k_objectInfo150.txt +151 -0
- lseg_app.py +386 -0
- modules/lseg_module.py +183 -0
- modules/lsegmentation_module.py +304 -0
- modules/models/lseg_blocks.py +359 -0
- modules/models/lseg_net.py +231 -0
- modules/models/lseg_vit.py +535 -0
- prepare_ade20k.py +45 -0
- test_lseg.py +436 -0
- train_lseg.py +7 -0
- utils.py +368 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Intelligent Systems Lab Org
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
additional_utils/encoding_models.py
ADDED
@@ -0,0 +1,164 @@
+###########################################################################
+# Referred to: https://github.com/zhanghang1989/PyTorch-Encoding
+###########################################################################
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.scatter_gather import scatter
+import threading
+import torch
+from torch.cuda._utils import _get_device_index
+from torch.cuda.amp import autocast
+from torch._utils import ExceptionWrapper
+
+up_kwargs = {'mode': 'bilinear', 'align_corners': True}
+
+__all__ = ['MultiEvalModule']
+
+class MultiEvalModule(DataParallel):
+    """Multi-size Segmentation Evaluator"""
+    def __init__(self, module, nclass, device_ids=None, flip=True,
+                 scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
+        super(MultiEvalModule, self).__init__(module, device_ids)
+        self.nclass = nclass
+        self.base_size = module.base_size
+        self.crop_size = module.crop_size
+        self.scales = scales
+        self.flip = flip
+        print('MultiEvalModule: base_size {}, crop_size {}'. \
+            format(self.base_size, self.crop_size))
+
+    def parallel_forward(self, inputs, **kwargs):
+        """Multi-GPU Multi-size Evaluation
+
+        Args:
+            inputs: list of Tensors
+        """
+        inputs = [(input.unsqueeze(0).cuda(device),)
+                  for input, device in zip(inputs, self.device_ids)]
+        replicas = self.replicate(self, self.device_ids[:len(inputs)])
+        # scatter any keyword arguments across the active devices
+        kwargs = scatter(kwargs, self.device_ids[:len(inputs)]) if kwargs else []
+        if len(inputs) < len(kwargs):
+            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        elif len(kwargs) < len(inputs):
+            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        outputs = self.parallel_apply(replicas, inputs, kwargs)
+        #for out in outputs:
+        #    print('out.size()', out.size())
+        return outputs
+
+    def forward(self, image):
+        """Multi-size Evaluation"""
+        # only single image is supported for evaluation
+        batch, _, h, w = image.size()
+        assert(batch == 1)
+        stride_rate = 2.0/3.0
+        crop_size = self.crop_size
+        stride = int(crop_size * stride_rate)
+        with torch.cuda.device_of(image):
+            scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
+
+        for scale in self.scales:
+            long_size = int(math.ceil(self.base_size * scale))
+            if h > w:
+                height = long_size
+                width = int(1.0 * w * long_size / h + 0.5)
+                short_size = width
+            else:
+                width = long_size
+                height = int(1.0 * h * long_size / w + 0.5)
+                short_size = height
+            """
+            short_size = int(math.ceil(self.base_size * scale))
+            if h > w:
+                width = short_size
+                height = int(1.0 * h * short_size / w)
+                long_size = height
+            else:
+                height = short_size
+                width = int(1.0 * w * short_size / h)
+                long_size = width
+            """
+            # resize image to current size
+            cur_img = resize_image(image, height, width, **self.module._up_kwargs)
+            if long_size <= crop_size:
+                pad_img = pad_image(cur_img, self.module.mean,
+                                    self.module.std, crop_size)
+                outputs = module_inference(self.module, pad_img, self.flip)
+                outputs = crop_image(outputs, 0, height, 0, width)
+            else:
+                if short_size < crop_size:
+                    # pad if needed
+                    pad_img = pad_image(cur_img, self.module.mean,
+                                        self.module.std, crop_size)
+                else:
+                    pad_img = cur_img
+                _,_,ph,pw = pad_img.size()
+                assert(ph >= height and pw >= width)
+                # grid forward and normalize
+                h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1
+                w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1
+                with torch.cuda.device_of(image):
+                    outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
+                    count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
+                # grid evaluation
+                for idh in range(h_grids):
+                    for idw in range(w_grids):
+                        h0 = idh * stride
+                        w0 = idw * stride
+                        h1 = min(h0 + crop_size, ph)
+                        w1 = min(w0 + crop_size, pw)
+                        crop_img = crop_image(pad_img, h0, h1, w0, w1)
+                        # pad if needed
+                        pad_crop_img = pad_image(crop_img, self.module.mean,
+                                                 self.module.std, crop_size)
+                        output = module_inference(self.module, pad_crop_img, self.flip)
+                        outputs[:,:,h0:h1,w0:w1] += crop_image(output,
+                                                               0, h1-h0, 0, w1-w0)
+                        count_norm[:,:,h0:h1,w0:w1] += 1
+                assert((count_norm==0).sum()==0)
+                outputs = outputs / count_norm
+                outputs = outputs[:,:,:height,:width]
+
+            score = resize_image(outputs, h, w, **self.module._up_kwargs)
+            scores += score
+
+        return scores
+
+
+def module_inference(module, image, flip=True):
+    output = module.evaluate(image)
+    if flip:
+        fimg = flip_image(image)
+        foutput = module.evaluate(fimg)
+        output += flip_image(foutput)
+    return output
+
+def resize_image(img, h, w, **up_kwargs):
+    return F.interpolate(img, (h, w), **up_kwargs)
+
+def pad_image(img, mean, std, crop_size):
+    b,c,h,w = img.size()
+    assert(c==3)
+    padh = crop_size - h if h < crop_size else 0
+    padw = crop_size - w if w < crop_size else 0
+    pad_values = -np.array(mean) / np.array(std)
+    img_pad = img.new().resize_(b,c,h+padh,w+padw)
+    for i in range(c):
+        # note that pytorch pad params is in reversed orders
+        img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
+    assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
+    return img_pad

+def crop_image(img, h0, h1, w0, w1):
+    return img[:,:,h0:h1,w0:w1]
+
+def flip_image(img):
+    assert(img.dim()==4)
+    with torch.cuda.device_of(img):
+        idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
+    return img.index_select(3, idx)
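
A minimal sketch of how MultiEvalModule is meant to be driven. The DummySegNet below is a hypothetical stand-in (not part of this commit); per the code above, the wrapped module only needs base_size, crop_size, mean, std, _up_kwargs and an evaluate() method, and a CUDA device is assumed.

import torch
import torch.nn as nn
from additional_utils.encoding_models import MultiEvalModule, up_kwargs

class DummySegNet(nn.Module):
    """Hypothetical stand-in for a real segmentation network."""
    def __init__(self, nclass=150):
        super().__init__()
        self.conv = nn.Conv2d(3, nclass, 1)
        self.base_size, self.crop_size = 520, 480
        self.mean, self.std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        self._up_kwargs = up_kwargs
    def evaluate(self, x):
        # MultiEvalModule calls module.evaluate(image) on each padded crop
        return self.conv(x)

net = DummySegNet()
evaluator = MultiEvalModule(net, nclass=150, flip=True).cuda().eval()
image = torch.rand(1, 3, 512, 683).cuda()   # evaluation expects batch size 1
with torch.no_grad():
    scores = evaluator(image)               # (1, nclass, H, W) accumulated scores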
additional_utils/models.py
ADDED
@@ -0,0 +1,250 @@
+###########################################################################
+# Referred to: https://github.com/zhanghang1989/PyTorch-Encoding
+###########################################################################
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.scatter_gather import scatter
+import threading
+import torch
+from torch.cuda._utils import _get_device_index
+from torch.cuda.amp import autocast
+from torch._utils import ExceptionWrapper
+
+up_kwargs = {'mode': 'bilinear', 'align_corners': True}
+
+__all__ = ['LSeg_MultiEvalModule']
+
+
+class LSeg_MultiEvalModule(DataParallel):
+    """Multi-size Segmentation Evaluator"""
+    def __init__(self, module, device_ids=None, flip=True,
+                 scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
+        super(LSeg_MultiEvalModule, self).__init__(module, device_ids)
+        self.base_size = module.base_size
+        self.crop_size = module.crop_size
+        self.scales = scales
+        self.flip = flip
+        print('MultiEvalModule: base_size {}, crop_size {}'. \
+            format(self.base_size, self.crop_size))
+
+    def parallel_forward(self, inputs, label_set='', **kwargs):
+        """Multi-GPU Multi-size Evaluation
+
+        Args:
+            inputs: list of Tensors
+        """
+        if len(label_set) < 10:
+            print('** MultiEvalModule parallel_forward phase: {} **'.format(label_set))
+        self.nclass = len(label_set)
+        inputs = [(input.unsqueeze(0).cuda(device),)
+                  for input, device in zip(inputs, self.device_ids)]
+        replicas = self.replicate(self, self.device_ids[:len(inputs)])
+        # scatter any keyword arguments across the active devices
+        kwargs = scatter(kwargs, self.device_ids[:len(inputs)]) if kwargs else []
+        if len(inputs) < len(kwargs):
+            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        elif len(kwargs) < len(inputs):
+            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        outputs = parallel_apply(replicas, inputs, label_set, kwargs)
+        return outputs
+
+    def forward(self, image, label_set=''):
+        """Multi-size Evaluation"""
+        # only single image is supported for evaluation
+        if len(label_set) < 10:
+            print('** MultiEvalModule forward phase: {} **'.format(label_set))
+        batch, _, h, w = image.size()
+        assert(batch == 1)
+        self.nclass = len(label_set)
+        stride_rate = 2.0/3.0
+        crop_size = self.crop_size
+        stride = int(crop_size * stride_rate)
+        with torch.cuda.device_of(image):
+            scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
+
+        for scale in self.scales:
+            long_size = int(math.ceil(self.base_size * scale))
+            if h > w:
+                height = long_size
+                width = int(1.0 * w * long_size / h + 0.5)
+                short_size = width
+            else:
+                width = long_size
+                height = int(1.0 * h * long_size / w + 0.5)
+                short_size = height
+            """
+            short_size = int(math.ceil(self.base_size * scale))
+            if h > w:
+                width = short_size
+                height = int(1.0 * h * short_size / w)
+                long_size = height
+            else:
+                height = short_size
+                width = int(1.0 * w * short_size / h)
+                long_size = width
+            """
+            # resize image to current size
+            cur_img = resize_image(image, height, width, **self.module._up_kwargs)
+            if long_size <= crop_size:
+                pad_img = pad_image(cur_img, self.module.mean,
+                                    self.module.std, crop_size)
+                outputs = module_inference(self.module, pad_img, label_set, self.flip)
+                outputs = crop_image(outputs, 0, height, 0, width)
+            else:
+                if short_size < crop_size:
+                    # pad if needed
+                    pad_img = pad_image(cur_img, self.module.mean,
+                                        self.module.std, crop_size)
+                else:
+                    pad_img = cur_img
+                _,_,ph,pw = pad_img.shape #.size()
+                assert(ph >= height and pw >= width)
+                # grid forward and normalize
+                h_grids = int(math.ceil(1.0 * (ph-crop_size)/stride)) + 1
+                w_grids = int(math.ceil(1.0 * (pw-crop_size)/stride)) + 1
+                with torch.cuda.device_of(image):
+                    outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
+                    count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
+                # grid evaluation
+                for idh in range(h_grids):
+                    for idw in range(w_grids):
+                        h0 = idh * stride
+                        w0 = idw * stride
+                        h1 = min(h0 + crop_size, ph)
+                        w1 = min(w0 + crop_size, pw)
+                        crop_img = crop_image(pad_img, h0, h1, w0, w1)
+                        # pad if needed
+                        pad_crop_img = pad_image(crop_img, self.module.mean,
+                                                 self.module.std, crop_size)
+                        output = module_inference(self.module, pad_crop_img, label_set, self.flip)
+                        outputs[:,:,h0:h1,w0:w1] += crop_image(output,
+                                                               0, h1-h0, 0, w1-w0)
+                        count_norm[:,:,h0:h1,w0:w1] += 1
+                assert((count_norm==0).sum()==0)
+                outputs = outputs / count_norm
+                outputs = outputs[:,:,:height,:width]
+            score = resize_image(outputs, h, w, **self.module._up_kwargs)
+            scores += score
+        return scores
+
+def module_inference(module, image, label_set, flip=True):
+    output = module.evaluate_random(image, label_set)
+    if flip:
+        fimg = flip_image(image)
+        foutput = module.evaluate_random(fimg, label_set)
+        output += flip_image(foutput)
+    return output
+
+def resize_image(img, h, w, **up_kwargs):
+    return F.interpolate(img, (h, w), **up_kwargs)
+
+def pad_image(img, mean, std, crop_size):
+    b,c,h,w = img.shape #.size()
+    assert(c==3)
+    padh = crop_size - h if h < crop_size else 0
+    padw = crop_size - w if w < crop_size else 0
+    pad_values = -np.array(mean) / np.array(std)
+    img_pad = img.new().resize_(b,c,h+padh,w+padw)
+    for i in range(c):
+        # note that pytorch pad params is in reversed orders
+        img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
+    assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
+    return img_pad
+
+def crop_image(img, h0, h1, w0, w1):
+    return img[:,:,h0:h1,w0:w1]
+
+def flip_image(img):
+    assert(img.dim()==4)
+    with torch.cuda.device_of(img):
+        idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
+    return img.index_select(3, idx)
+
+
+def get_a_var(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+
+    if isinstance(obj, list) or isinstance(obj, tuple):
+        for result in map(get_a_var, obj):
+            if isinstance(result, torch.Tensor):
+                return result
+    if isinstance(obj, dict):
+        for result in map(get_a_var, obj.items()):
+            if isinstance(result, torch.Tensor):
+                return result
+    return None
+
+
+def parallel_apply(modules, inputs, label_set, kwargs_tup=None, devices=None):
+    r"""Applies each `module` in :attr:`modules` in parallel on arguments
+    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
+    on each of :attr:`devices`.
+
+    Args:
+        modules (Module): modules to be parallelized
+        inputs (tensor): inputs to the modules
+        devices (list of int or torch.device): CUDA devices
+
+    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
+    :attr:`devices` (if given) should all have same length. Moreover, each
+    element of :attr:`inputs` can either be a single object as the only argument
+    to a module, or a collection of positional arguments.
+    """
+    assert len(modules) == len(inputs)
+    if kwargs_tup is not None:
+        assert len(modules) == len(kwargs_tup)
+    else:
+        kwargs_tup = ({},) * len(modules)
+    if devices is not None:
+        assert len(modules) == len(devices)
+    else:
+        devices = [None] * len(modules)
+    devices = [_get_device_index(x, True) for x in devices]
+    lock = threading.Lock()
+    results = {}
+    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
+
+    def _worker(i, module, input, label_set, kwargs, device=None):
+        torch.set_grad_enabled(grad_enabled)
+        if device is None:
+            device = get_a_var(input).get_device()
+        try:
+            with torch.cuda.device(device), autocast(enabled=autocast_enabled):
+                # this also avoids accidental slicing of `input` if it is a Tensor
+                if not isinstance(input, (list, tuple)):
+                    input = (input,)
+                output = module(*input, label_set, **kwargs)
+            with lock:
+                results[i] = output
+        except Exception:
+            with lock:
+                results[i] = ExceptionWrapper(
+                    where="in replica {} on device {}".format(i, device))
+
+    if len(modules) > 1:
+        threads = [threading.Thread(target=_worker,
+                                    args=(i, module, input, label_set, kwargs, device))
+                   for i, (module, input, kwargs, device) in
+                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
+
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+    else:
+        _worker(0, modules[0], inputs[0], label_set, kwargs_tup[0], devices[0])
+
+    outputs = []
+    for i in range(len(inputs)):
+        output = results[i]
+        if isinstance(output, ExceptionWrapper):
+            output.reraise()
+        outputs.append(output)
+    return outputs
+
+
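
The LSeg variant above threads the free-form label_set through every call, and its module-level parallel_apply hands it to each replica as an extra positional argument. A single-image sketch, assuming `model` is an LSegModule-style network (see modules/ below) that exposes evaluate_random(image, labelset) plus the same size and normalization attributes; it is not a runnable standalone script.

import torch
from additional_utils.models import LSeg_MultiEvalModule

labels = ['dog', 'grass', 'other']          # free-form zero-shot labels
evaluator = LSeg_MultiEvalModule(model, scales=[0.75, 1.0, 1.25], flip=True).cuda()
evaluator.eval()

image = torch.rand(1, 3, 360, 480).cuda()   # batch size must be 1
with torch.no_grad():
    scores = evaluator(image, labels)       # (1, len(labels), H, W) class scores
    # multi-GPU path, as used by lseg_app.py below:
    # outputs = evaluator.parallel_forward(image, labels)
pred = scores.argmax(dim=1)                 # per-pixel label indices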
data/__init__.py
ADDED
@@ -0,0 +1,24 @@
+import copy
+
+import itertools
+import functools
+import numpy as np
+import torch
+import torch.utils.data
+import torchvision.transforms as torch_transforms
+import encoding.datasets as enc_ds
+
+encoding_datasets = {
+    x: functools.partial(enc_ds.get_dataset, x)
+    for x in ["coco", "ade20k", "pascal_voc", "pascal_aug", "pcontext", "citys"]
+}
+
+
+def get_dataset(name, **kwargs):
+    if name in encoding_datasets:
+        return encoding_datasets[name.lower()](**kwargs)
+    assert False, f"dataset {name} not found"
+
+
+def get_available_datasets():
+    return list(encoding_datasets.keys())
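
A short sketch of how these helpers are called elsewhere in this commit; the keyword arguments mirror get_trainset in modules/lsegmentation_module.py, and the root path is a placeholder.

import torchvision.transforms as transforms
from data import get_dataset, get_available_datasets

print(get_available_datasets())          # ['coco', 'ade20k', 'pascal_voc', ...]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
trainset = get_dataset(
    "ade20k",
    root="../datasets/",                 # placeholder dataset root
    split="train",
    mode="train",
    transform=transform,
)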
label_files/ade20k_objectInfo150.txt
ADDED
@@ -0,0 +1,151 @@
+Idx,Ratio,Train,Val,Stuff,Name
+1,0.1576,11664,1172,1,wall
+2,0.1072,6046,612,1,building;edifice
+3,0.0878,8265,796,1,sky
+4,0.0621,9336,917,1,floor;flooring
+5,0.0480,6678,641,0,tree
+6,0.0450,6604,643,1,ceiling
+7,0.0398,4023,408,1,road;route
+8,0.0231,1906,199,0,bed
+9,0.0198,4688,460,0,windowpane;window
+10,0.0183,2423,225,1,grass
+11,0.0181,2874,294,0,cabinet
+12,0.0166,3068,310,1,sidewalk;pavement
+13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
+14,0.0151,1804,190,1,earth;ground
+15,0.0118,6666,796,0,door;double;door
+16,0.0110,4269,411,0,table
+17,0.0109,1691,160,1,mountain;mount
+18,0.0104,3999,441,0,plant;flora;plant;life
+19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
+20,0.0103,3261,318,0,chair
+21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
+22,0.0074,709,75,1,water
+23,0.0067,3296,315,0,painting;picture
+24,0.0065,1191,106,0,sofa;couch;lounge
+25,0.0061,1516,162,0,shelf
+26,0.0060,667,69,1,house
+27,0.0053,651,57,1,sea
+28,0.0052,1847,224,0,mirror
+29,0.0046,1158,128,1,rug;carpet;carpeting
+30,0.0044,480,44,1,field
+31,0.0044,1172,98,0,armchair
+32,0.0044,1292,184,0,seat
+33,0.0033,1386,138,0,fence;fencing
+34,0.0031,698,61,0,desk
+35,0.0030,781,73,0,rock;stone
+36,0.0027,380,43,0,wardrobe;closet;press
+37,0.0026,3089,302,0,lamp
+38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
+39,0.0024,804,99,0,railing;rail
+40,0.0023,1453,153,0,cushion
+41,0.0023,411,37,0,base;pedestal;stand
+42,0.0022,1440,162,0,box
+43,0.0022,800,77,0,column;pillar
+44,0.0020,2650,298,0,signboard;sign
+45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
+46,0.0019,367,36,0,counter
+47,0.0018,311,30,1,sand
+48,0.0018,1181,122,0,sink
+49,0.0018,287,23,1,skyscraper
+50,0.0018,468,38,0,fireplace;hearth;open;fireplace
+51,0.0018,402,43,0,refrigerator;icebox
+52,0.0018,130,12,1,grandstand;covered;stand
+53,0.0018,561,64,1,path
+54,0.0017,880,102,0,stairs;steps
+55,0.0017,86,12,1,runway
+56,0.0017,172,11,0,case;display;case;showcase;vitrine
+57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
+58,0.0017,930,109,0,pillow
+59,0.0015,139,18,0,screen;door;screen
+60,0.0015,564,52,1,stairway;staircase
+61,0.0015,320,26,1,river
+62,0.0015,261,29,1,bridge;span
+63,0.0014,275,22,0,bookcase
+64,0.0014,335,60,0,blind;screen
+65,0.0014,792,75,0,coffee;table;cocktail;table
+66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
+67,0.0014,1309,138,0,flower
+68,0.0013,1112,113,0,book
+69,0.0013,266,27,1,hill
+70,0.0013,659,66,0,bench
+71,0.0012,331,31,0,countertop
+72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
+73,0.0012,369,36,0,palm;palm;tree
+74,0.0012,144,9,0,kitchen;island
+75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
+76,0.0010,324,33,0,swivel;chair
+77,0.0009,304,27,0,boat
+78,0.0009,170,20,0,bar
+79,0.0009,68,6,0,arcade;machine
+80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
+81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
+82,0.0008,492,49,0,towel
+83,0.0008,2510,269,0,light;light;source
+84,0.0008,440,39,0,truck;motortruck
+85,0.0008,147,18,1,tower
+86,0.0008,583,56,0,chandelier;pendant;pendent
+87,0.0007,533,61,0,awning;sunshade;sunblind
+88,0.0007,1989,239,0,streetlight;street;lamp
+89,0.0007,71,5,0,booth;cubicle;stall;kiosk
+90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
+91,0.0007,135,12,0,airplane;aeroplane;plane
+92,0.0007,83,5,1,dirt;track
+93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
+94,0.0006,1003,104,0,pole
+95,0.0006,182,12,1,land;ground;soil
+96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
+97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
+98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
+99,0.0006,965,114,0,bottle
+100,0.0006,117,13,0,buffet;counter;sideboard
+101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
+102,0.0006,108,9,1,stage
+103,0.0006,557,55,0,van
+104,0.0006,52,4,0,ship
+105,0.0005,99,5,0,fountain
+106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
+107,0.0005,292,31,0,canopy
+108,0.0005,77,9,0,washer;automatic;washer;washing;machine
+109,0.0005,340,38,0,plaything;toy
+110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
+111,0.0005,465,49,0,stool
+112,0.0005,50,4,0,barrel;cask
+113,0.0005,622,75,0,basket;handbasket
+114,0.0005,80,9,1,waterfall;falls
+115,0.0005,59,3,0,tent;collapsible;shelter
+116,0.0005,531,72,0,bag
+117,0.0005,282,30,0,minibike;motorbike
+118,0.0005,73,7,0,cradle
+119,0.0005,435,44,0,oven
+120,0.0005,136,25,0,ball
+121,0.0005,116,24,0,food;solid;food
+122,0.0004,266,31,0,step;stair
+123,0.0004,58,12,0,tank;storage;tank
+124,0.0004,418,83,0,trade;name;brand;name;brand;marque
+125,0.0004,319,43,0,microwave;microwave;oven
+126,0.0004,1193,139,0,pot;flowerpot
+127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
+128,0.0004,347,36,0,bicycle;bike;wheel;cycle
+129,0.0004,52,5,1,lake
+130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
+131,0.0004,108,13,0,screen;silver;screen;projection;screen
+132,0.0004,201,30,0,blanket;cover
+133,0.0004,285,21,0,sculpture
+134,0.0004,268,27,0,hood;exhaust;hood
+135,0.0003,1020,108,0,sconce
+136,0.0003,1282,122,0,vase
+137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
+138,0.0003,453,57,0,tray
+139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
+140,0.0003,397,44,0,fan
+141,0.0003,92,8,1,pier;wharf;wharfage;dock
+142,0.0003,228,18,0,crt;screen
+143,0.0003,570,59,0,plate
+144,0.0003,217,22,0,monitor;monitoring;device
+145,0.0003,206,19,0,bulletin;board;notice;board
+146,0.0003,130,14,0,shower
+147,0.0003,178,28,0,radiator
+148,0.0002,504,57,0,glass;drinking;glass
+149,0.0002,775,96,0,clock
+150,0.0002,421,56,0,flag
lseg_app.py
ADDED
@@ -0,0 +1,386 @@
+from collections import namedtuple
+import altair as alt
+import math
+import pandas as pd
+import streamlit as st
+st.set_page_config(layout="wide")
+
+from PIL import Image
+
+import os
+import torch
+
+import os
+import argparse
+import numpy as np
+from tqdm import tqdm
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+from torch.utils import data
+import torchvision.transforms as transform
+from torch.nn.parallel.scatter_gather import gather
+
+from additional_utils.models import LSeg_MultiEvalModule
+from modules.lseg_module import LSegModule
+
+import cv2
+import math
+import types
+import functools
+import torchvision.transforms as torch_transforms
+import copy
+import itertools
+from PIL import Image
+import matplotlib.pyplot as plt
+import clip
+from encoding.models.sseg import BaseNet
+import matplotlib as mpl
+import matplotlib.colors as mplc
+import matplotlib.figure as mplfigure
+import matplotlib.patches as mpatches
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+from data import get_dataset
+import torchvision.transforms as transforms
+
+
+def get_new_pallete(num_cls):
+    n = num_cls
+    pallete = [0]*(n*3)
+    for j in range(0,n):
+        lab = j
+        pallete[j*3+0] = 0
+        pallete[j*3+1] = 0
+        pallete[j*3+2] = 0
+        i = 0
+        while (lab > 0):
+            pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
+            pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
+            pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
+            i = i + 1
+            lab >>= 3
+    return pallete
+
+def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
+    """Get image color palette for visualizing masks"""
+    # put colormap
+    out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
+    out_img.putpalette(new_palette)
+
+    if out_label_flag:
+        assert labels is not None
+        u_index = np.unique(npimg)
+        patches = []
+        for i, index in enumerate(u_index):
+            label = labels[index]
+            cur_color = [new_palette[index * 3] / 255.0, new_palette[index * 3 + 1] / 255.0, new_palette[index * 3 + 2] / 255.0]
+            red_patch = mpatches.Patch(color=cur_color, label=label)
+            patches.append(red_patch)
+    return out_img, patches
+
+@st.cache(allow_output_mutation=True)
+def load_model():
+    class Options:
+        def __init__(self):
+            parser = argparse.ArgumentParser(description="PyTorch Segmentation")
+            # model and dataset
+            parser.add_argument(
+                "--model", type=str, default="encnet", help="model name (default: encnet)"
+            )
+            parser.add_argument(
+                "--backbone",
+                type=str,
+                default="clip_vitl16_384",
+                help="backbone name (default: resnet50)",
+            )
+            parser.add_argument(
+                "--dataset",
+                type=str,
+                default="ade20k",
+                help="dataset name (default: pascal12)",
+            )
+            parser.add_argument(
+                "--workers", type=int, default=16, metavar="N", help="dataloader threads"
+            )
+            parser.add_argument(
+                "--base-size", type=int, default=520, help="base image size"
+            )
+            parser.add_argument(
+                "--crop-size", type=int, default=480, help="crop image size"
+            )
+            parser.add_argument(
+                "--train-split",
+                type=str,
+                default="train",
+                help="dataset train split (default: train)",
+            )
+            parser.add_argument(
+                "--aux", action="store_true", default=False, help="Auxiliary Loss"
+            )
+            parser.add_argument(
+                "--se-loss",
+                action="store_true",
+                default=False,
+                help="Semantic Encoding Loss SE-loss",
+            )
+            parser.add_argument(
+                "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
+            )
+            parser.add_argument(
+                "--batch-size",
+                type=int,
+                default=16,
+                metavar="N",
+                help="input batch size for \
+                training (default: auto)",
+            )
+            parser.add_argument(
+                "--test-batch-size",
+                type=int,
+                default=16,
+                metavar="N",
+                help="input batch size for \
+                testing (default: same as batch size)",
+            )
+            # cuda, seed and logging
+            parser.add_argument(
+                "--no-cuda",
+                action="store_true",
+                default=False,
+                help="disables CUDA training",
+            )
+            parser.add_argument(
+                "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+            )
+            # checking point
+            parser.add_argument(
+                "--weights", type=str, default='', help="checkpoint to test"
+            )
+            # evaluation option
+            parser.add_argument(
+                "--eval", action="store_true", default=False, help="evaluating mIoU"
+            )
+            parser.add_argument(
+                "--export",
+                type=str,
+                default=None,
+                help="put the path to resuming file if needed",
+            )
+            parser.add_argument(
+                "--acc-bn",
+                action="store_true",
+                default=False,
+                help="Re-accumulate BN statistics",
+            )
+            parser.add_argument(
+                "--test-val",
+                action="store_true",
+                default=False,
+                help="generate masks on val set",
+            )
+            parser.add_argument(
+                "--no-val",
+                action="store_true",
+                default=False,
+                help="skip validation during training",
+            )
+
+            parser.add_argument(
+                "--module",
+                default='lseg',
+                help="select model definition",
+            )
+
+            # test option
+            parser.add_argument(
+                "--data-path", type=str, default='../datasets/', help="path to test image folder"
+            )
+
+            parser.add_argument(
+                "--no-scaleinv",
+                dest="scale_inv",
+                default=True,
+                action="store_false",
+                help="turn off scaleinv layers",
+            )
+
+            parser.add_argument(
+                "--widehead", default=False, action="store_true", help="wider output head"
+            )
+
+            parser.add_argument(
+                "--widehead_hr",
+                default=False,
+                action="store_true",
+                help="wider output head",
+            )
+            parser.add_argument(
+                "--ignore_index",
+                type=int,
+                default=-1,
+                help="numeric value of ignore label in gt",
+            )
+
+            parser.add_argument(
+                "--label_src",
+                type=str,
+                default="default",
+                help="how to get the labels",
+            )
+
+            parser.add_argument(
+                "--arch_option",
+                type=int,
+                default=0,
+                help="which kind of architecture to be used",
+            )
+
+            parser.add_argument(
+                "--block_depth",
+                type=int,
+                default=0,
+                help="how many blocks should be used",
+            )
+
+            parser.add_argument(
+                "--activation",
+                choices=['lrelu', 'tanh'],
+                default="lrelu",
+                help="use which activation to activate the block",
+            )
+
+            self.parser = parser
+
+        def parse(self):
+            args = self.parser.parse_args(args=[])
+            args.cuda = not args.no_cuda and torch.cuda.is_available()
+            print(args)
+            return args
+
+    args = Options().parse()
+
+    torch.manual_seed(args.seed)
+    args.test_batch_size = 1
+    alpha=0.5
+
+    args.scale_inv = False
+    args.widehead = True
+    args.dataset = 'ade20k'
+    args.backbone = 'clip_vitl16_384'
+    args.weights = 'checkpoints/demo_e200.ckpt'
+    args.ignore_index = 255
+
+    module = LSegModule.load_from_checkpoint(
+        checkpoint_path=args.weights,
+        data_path=args.data_path,
+        dataset=args.dataset,
+        backbone=args.backbone,
+        aux=args.aux,
+        num_features=256,
+        aux_weight=0,
+        se_loss=False,
+        se_weight=0,
+        base_lr=0,
+        batch_size=1,
+        max_epochs=0,
+        ignore_index=args.ignore_index,
+        dropout=0.0,
+        scale_inv=args.scale_inv,
+        augment=False,
+        no_batchnorm=False,
+        widehead=args.widehead,
+        widehead_hr=args.widehead_hr,
+        map_location="cpu",
+        arch_option=0,
+        block_depth=0,
+        activation='lrelu',
+    )
+
+    input_transform = module.val_transform
+
+    # dataloader
+    loader_kwargs = (
+        {"num_workers": args.workers, "pin_memory": True} if args.cuda else {}
+    )
+
+    # model
+    if isinstance(module.net, BaseNet):
+        model = module.net
+    else:
+        model = module
+
+    model = model.eval()
+    model = model.cpu()
+    scales = (
+        [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
+        if args.dataset == "citys"
+        else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+    )
+
+    model.mean = [0.5, 0.5, 0.5]
+    model.std = [0.5, 0.5, 0.5]
+    evaluator = LSeg_MultiEvalModule(
+        model, scales=scales, flip=True
+    ).cuda()
+    evaluator.eval()
+
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+            transforms.Resize([360,480]),
+        ]
+    )
+
+    return evaluator, transform
+
+"""
+# LSeg Demo
+"""
+lseg_model, lseg_transform = load_model()
+uploaded_file = st.file_uploader("Choose an image...")
+input_labels = st.text_input("Input labels", value="dog, grass, other")
+st.write("The labels are", input_labels)
+
+if uploaded_file is not None:
+    image = Image.open(uploaded_file)
+    pimage = lseg_transform(np.array(image)).unsqueeze(0)
+
+    labels = []
+    for label in input_labels.split(","):
+        labels.append(label.strip())
+
+    with torch.no_grad():
+        outputs = lseg_model.parallel_forward(pimage, labels)
+
+        predicts = [
+            torch.max(output, 1)[1].cpu().numpy()
+            for output in outputs
+        ]
+
+    image = pimage[0].permute(1,2,0)
+    image = image * 0.5 + 0.5
+    image = Image.fromarray(np.uint8(255*image)).convert("RGBA")
+
+    pred = predicts[0]
+    new_palette = get_new_pallete(len(labels))
+    mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
+    seg = mask.convert("RGBA")
+
+    fig = plt.figure()
+    plt.subplot(121)
+    plt.imshow(image)
+    plt.axis('off')
+
+    plt.subplot(122)
+    plt.imshow(seg)
+    plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
+    plt.axis('off')
+
+    plt.tight_layout()
+
+    #st.image([image,seg], width=700, caption=["Input image", "Segmentation"])
+    st.pyplot(fig)
+
+
modules/lseg_module.py
ADDED
@@ -0,0 +1,183 @@
+import re
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+from argparse import ArgumentParser
+import pytorch_lightning as pl
+from .lsegmentation_module import LSegmentationModule
+from .models.lseg_net import LSegNet
+from encoding.models.sseg.base import up_kwargs
+
+import os
+import clip
+import numpy as np
+
+from scipy import signal
+import glob
+
+from PIL import Image
+import matplotlib.pyplot as plt
+import pandas as pd
+
+
+class LSegModule(LSegmentationModule):
+    def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
+        super(LSegModule, self).__init__(
+            data_path, dataset, batch_size, base_lr, max_epochs, **kwargs
+        )
+
+        if dataset == "citys":
+            self.base_size = 2048
+            self.crop_size = 768
+        else:
+            self.base_size = 520
+            self.crop_size = 480
+
+        use_pretrained = True
+        norm_mean= [0.5, 0.5, 0.5]
+        norm_std = [0.5, 0.5, 0.5]
+
+        print('** Use norm {}, {} as the mean and std **'.format(norm_mean, norm_std))
+
+        train_transform = [
+            transforms.ToTensor(),
+            transforms.Normalize(norm_mean, norm_std),
+        ]
+
+        val_transform = [
+            transforms.ToTensor(),
+            transforms.Normalize(norm_mean, norm_std),
+        ]
+
+        self.train_transform = transforms.Compose(train_transform)
+        self.val_transform = transforms.Compose(val_transform)
+
+        self.trainset = self.get_trainset(
+            dataset,
+            augment=kwargs["augment"],
+            base_size=self.base_size,
+            crop_size=self.crop_size,
+        )
+
+        self.valset = self.get_valset(
+            dataset,
+            augment=kwargs["augment"],
+            base_size=self.base_size,
+            crop_size=self.crop_size,
+        )
+
+        use_batchnorm = (
+            (not kwargs["no_batchnorm"]) if "no_batchnorm" in kwargs else True
+        )
+        # print(kwargs)
+
+        labels = self.get_labels('ade20k')
+
+        self.net = LSegNet(
+            labels=labels,
+            backbone=kwargs["backbone"],
+            features=kwargs["num_features"],
+            crop_size=self.crop_size,
+            arch_option=kwargs["arch_option"],
+            block_depth=kwargs["block_depth"],
+            activation=kwargs["activation"],
+        )
+
+        self.net.pretrained.model.patch_embed.img_size = (
+            self.crop_size,
+            self.crop_size,
+        )
+
+        self._up_kwargs = up_kwargs
+        self.mean = norm_mean
+        self.std = norm_std
+
+        self.criterion = self.get_criterion(**kwargs)
+
+    def get_labels(self, dataset):
+        labels = []
+        path = 'label_files/{}_objectInfo150.txt'.format(dataset)
+        assert os.path.exists(path), '*** Error : {} not exist !!!'.format(path)
+        f = open(path, 'r')
+        lines = f.readlines()
+        for line in lines:
+            label = line.strip().split(',')[-1].split(';')[0]
+            labels.append(label)
+        f.close()
+        if dataset in ['ade20k']:
+            labels = labels[1:]
+        return labels
+
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = LSegmentationModule.add_model_specific_args(parent_parser)
+        parser = ArgumentParser(parents=[parser])
+
+        parser.add_argument(
+            "--backbone",
+            type=str,
+            default="clip_vitl16_384",
+            help="backbone network",
+        )
+
+        parser.add_argument(
+            "--num_features",
+            type=int,
+            default=256,
+            help="number of features that go from encoder to decoder",
+        )
+
+        parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")
+
+        parser.add_argument(
+            "--finetune_weights", type=str, help="load weights to finetune from"
+        )
+
+        parser.add_argument(
+            "--no-scaleinv",
+            default=True,
+            action="store_false",
+            help="turn off scaleinv layers",
+        )
+
+        parser.add_argument(
+            "--no-batchnorm",
+            default=False,
+            action="store_true",
+            help="turn off batchnorm",
+        )
+
+        parser.add_argument(
+            "--widehead", default=False, action="store_true", help="wider output head"
+        )
+
+        parser.add_argument(
+            "--widehead_hr",
+            default=False,
+            action="store_true",
+            help="wider output head",
+        )
+
+        parser.add_argument(
+            "--arch_option",
+            type=int,
+            default=0,
+            help="which kind of architecture to be used",
+        )
+
+        parser.add_argument(
+            "--block_depth",
+            type=int,
+            default=0,
+            help="how many blocks should be used",
+        )
+
+        parser.add_argument(
+            "--activation",
+            choices=['lrelu', 'tanh'],
+            default="lrelu",
+            help="use which activation to activate the block",
+        )
+
+        return parser
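
A sketch of how the chained argument parsers compose for a training entry point, roughly what train_lseg.py (listed above but not shown here) is expected to do; the sample flags are illustrative only.

from argparse import ArgumentParser
from modules.lseg_module import LSegModule

# add_help=False on the root parser: the final chained parser adds its own -h
parser = LSegModule.add_model_specific_args(ArgumentParser(add_help=False))
args = parser.parse_args(["--data_path", "../datasets/", "--backbone", "clip_vitl16_384"])
print(args.backbone, args.num_features, args.base_lr)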
modules/lsegmentation_module.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import types
|
2 |
+
import time
|
3 |
+
import random
|
4 |
+
import clip
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
|
9 |
+
from argparse import ArgumentParser
|
10 |
+
|
11 |
+
import pytorch_lightning as pl
|
12 |
+
|
13 |
+
from data import get_dataset, get_available_datasets
|
14 |
+
|
15 |
+
from encoding.models import get_segmentation_model
|
16 |
+
from encoding.nn import SegmentationLosses
|
17 |
+
|
18 |
+
from encoding.utils import batch_pix_accuracy, batch_intersection_union
|
19 |
+
|
20 |
+
# add mixed precision
|
21 |
+
import torch.cuda.amp as amp
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from encoding.utils import SegmentationMetric
|
25 |
+
|
26 |
+
class LSegmentationModule(pl.LightningModule):
|
27 |
+
def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.data_path = data_path
|
31 |
+
self.batch_size = batch_size
|
32 |
+
self.base_lr = base_lr / 16 * batch_size
|
33 |
+
self.lr = self.base_lr
|
34 |
+
|
35 |
+
self.epochs = max_epochs
|
36 |
+
self.other_kwargs = kwargs
|
37 |
+
self.enabled = False #True mixed precision will make things complicated and leading to NAN error
|
38 |
+
self.scaler = amp.GradScaler(enabled=self.enabled)
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
return self.net(x)
|
42 |
+
|
43 |
+
def evaluate(self, x, target=None):
|
44 |
+
pred = self.net.forward(x)
|
45 |
+
if isinstance(pred, (tuple, list)):
|
46 |
+
pred = pred[0]
|
47 |
+
if target is None:
|
48 |
+
return pred
|
49 |
+
correct, labeled = batch_pix_accuracy(pred.data, target.data)
|
50 |
+
inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
|
51 |
+
|
52 |
+
return correct, labeled, inter, union
|
53 |
+
|
54 |
+
def evaluate_random(self, x, labelset, target=None):
|
55 |
+
pred = self.net.forward(x, labelset)
|
56 |
+
if isinstance(pred, (tuple, list)):
|
57 |
+
pred = pred[0]
|
58 |
+
if target is None:
|
59 |
+
return pred
|
60 |
+
correct, labeled = batch_pix_accuracy(pred.data, target.data)
|
61 |
+
inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
|
62 |
+
|
63 |
+
return correct, labeled, inter, union
|
64 |
+
|
65 |
+
|
66 |
+
def training_step(self, batch, batch_nb):
|
67 |
+
img, target = batch
|
68 |
+
with amp.autocast(enabled=self.enabled):
|
69 |
+
out = self(img)
|
70 |
+
multi_loss = isinstance(out, tuple)
|
71 |
+
if multi_loss:
|
72 |
+
loss = self.criterion(*out, target)
|
73 |
+
else:
|
74 |
+
loss = self.criterion(out, target)
|
75 |
+
loss = self.scaler.scale(loss)
|
76 |
+
final_output = out[0] if multi_loss else out
|
77 |
+
train_pred, train_gt = self._filter_invalid(final_output, target)
|
78 |
+
if train_gt.nelement() != 0:
|
79 |
+
self.train_accuracy(train_pred, train_gt)
|
80 |
+
self.log("train_loss", loss)
|
81 |
+
return loss
|
82 |
+
|
83 |
+
def training_epoch_end(self, outs):
|
84 |
+
self.log("train_acc_epoch", self.train_accuracy.compute())
|
85 |
+
|
86 |
+
def validation_step(self, batch, batch_nb):
|
87 |
+
img, target = batch
|
88 |
+
out = self(img)
|
89 |
+
multi_loss = isinstance(out, tuple)
|
90 |
+
if multi_loss:
|
91 |
+
val_loss = self.criterion(*out, target)
|
92 |
+
else:
|
93 |
+
val_loss = self.criterion(out, target)
|
94 |
+
final_output = out[0] if multi_loss else out
|
95 |
+
valid_pred, valid_gt = self._filter_invalid(final_output, target)
|
96 |
+
self.val_iou.update(target, final_output)
|
97 |
+
pixAcc, iou = self.val_iou.get()
|
98 |
+
self.log("val_loss_step", val_loss)
|
99 |
+
self.log("pix_acc_step", pixAcc)
|
100 |
+
self.log(
|
101 |
+
"val_acc_step",
|
102 |
+
self.val_accuracy(valid_pred, valid_gt),
|
103 |
+
)
|
104 |
+
self.log("val_iou", iou)
|
105 |
+
|
106 |
+
def validation_epoch_end(self, outs):
|
107 |
+
pixAcc, iou = self.val_iou.get()
|
108 |
+
self.log("val_acc_epoch", self.val_accuracy.compute())
|
109 |
+
self.log("val_iou_epoch", iou)
|
110 |
+
self.log("pix_acc_epoch", pixAcc)
|
111 |
+
|
112 |
+
self.val_iou.reset()
|
113 |
+
|
114 |
+
def _filter_invalid(self, pred, target):
|
115 |
+
valid = target != self.other_kwargs["ignore_index"]
|
116 |
+
_, mx = torch.max(pred, dim=1)
|
117 |
+
return mx[valid], target[valid]
|
118 |
+
|
119 |
+
def configure_optimizers(self):
|
120 |
+
params_list = [
|
121 |
+
{"params": self.net.pretrained.parameters(), "lr": self.base_lr},
|
122 |
+
]
|
123 |
+
if hasattr(self.net, "scratch"):
|
124 |
+
print("Found output scratch")
|
125 |
+
params_list.append(
|
126 |
+
{"params": self.net.scratch.parameters(), "lr": self.base_lr * 10}
|
127 |
+
)
|
128 |
+
if hasattr(self.net, "auxlayer"):
|
129 |
+
print("Found auxlayer")
|
130 |
+
params_list.append(
|
131 |
+
{"params": self.net.auxlayer.parameters(), "lr": self.base_lr * 10}
|
132 |
+
)
|
133 |
+
if hasattr(self.net, "scale_inv_conv"):
|
134 |
+
print(self.net.scale_inv_conv)
|
135 |
+
print("Found scaleinv layers")
|
136 |
+
params_list.append(
|
137 |
+
{
|
138 |
+
"params": self.net.scale_inv_conv.parameters(),
|
139 |
+
"lr": self.base_lr * 10,
|
140 |
+
}
|
141 |
+
)
|
142 |
+
params_list.append(
|
143 |
+
{"params": self.net.scale2_conv.parameters(), "lr": self.base_lr * 10}
|
144 |
+
)
|
145 |
+
params_list.append(
|
146 |
+
{"params": self.net.scale3_conv.parameters(), "lr": self.base_lr * 10}
|
147 |
+
)
|
148 |
+
params_list.append(
|
149 |
+
{"params": self.net.scale4_conv.parameters(), "lr": self.base_lr * 10}
|
150 |
+
)
|
151 |
+
|
152 |
+
if self.other_kwargs["midasproto"]:
|
153 |
+
print("Using midas optimization protocol")
|
154 |
+
|
155 |
+
opt = torch.optim.Adam(
|
156 |
+
params_list,
|
157 |
+
lr=self.base_lr,
|
158 |
+
betas=(0.9, 0.999),
|
159 |
+
weight_decay=self.other_kwargs["weight_decay"],
|
160 |
+
)
|
161 |
+
sch = torch.optim.lr_scheduler.LambdaLR(
|
162 |
+
opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
|
163 |
+
)
|
164 |
+
|
165 |
+
else:
|
166 |
+
opt = torch.optim.SGD(
|
167 |
+
params_list,
|
168 |
+
lr=self.base_lr,
|
169 |
+
momentum=0.9,
|
170 |
+
weight_decay=self.other_kwargs["weight_decay"],
|
171 |
+
)
|
172 |
+
sch = torch.optim.lr_scheduler.LambdaLR(
|
173 |
+
opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
|
174 |
+
)
|
175 |
+
return [opt], [sch]
|
176 |
+
|
177 |
+
def train_dataloader(self):
|
178 |
+
return torch.utils.data.DataLoader(
|
179 |
+
self.trainset,
|
180 |
+
batch_size=self.batch_size,
|
181 |
+
shuffle=True,
|
182 |
+
num_workers=16,
|
183 |
+
worker_init_fn=lambda x: random.seed(time.time() + x),
|
184 |
+
)
|
185 |
+
|
186 |
+
def val_dataloader(self):
|
187 |
+
return torch.utils.data.DataLoader(
|
188 |
+
self.valset,
|
189 |
+
batch_size=self.batch_size,
|
190 |
+
shuffle=False,
|
191 |
+
num_workers=16,
|
192 |
+
)
|
193 |
+
|
194 |
+
def get_trainset(self, dset, augment=False, **kwargs):
|
195 |
+
print(kwargs)
|
196 |
+
if augment == True:
|
197 |
+
mode = "train_x"
|
198 |
+
else:
|
199 |
+
mode = "train"
|
200 |
+
|
201 |
+
print(mode)
|
202 |
+
dset = get_dataset(
|
203 |
+
dset,
|
204 |
+
root=self.data_path,
|
205 |
+
split="train",
|
206 |
+
mode=mode,
|
207 |
+
transform=self.train_transform,
|
208 |
+
**kwargs
|
209 |
+
)
|
210 |
+
|
211 |
+
self.num_classes = dset.num_class
|
212 |
+
self.train_accuracy = pl.metrics.Accuracy()
|
213 |
+
|
214 |
+
return dset
|
215 |
+
|
216 |
+
def get_valset(self, dset, augment=False, **kwargs):
|
217 |
+
self.val_accuracy = pl.metrics.Accuracy()
|
218 |
+
self.val_iou = SegmentationMetric(self.num_classes)
|
219 |
+
|
220 |
+
if augment == True:
|
221 |
+
mode = "val_x"
|
222 |
+
else:
|
223 |
+
mode = "val"
|
224 |
+
|
225 |
+
print(mode)
|
226 |
+
return get_dataset(
|
227 |
+
dset,
|
228 |
+
root=self.data_path,
|
229 |
+
split="val",
|
230 |
+
mode=mode,
|
231 |
+
transform=self.val_transform,
|
232 |
+
**kwargs
|
233 |
+
)
|
234 |
+
|
235 |
+
|
236 |
+
def get_criterion(self, **kwargs):
|
237 |
+
return SegmentationLosses(
|
238 |
+
se_loss=kwargs["se_loss"],
|
239 |
+
aux=kwargs["aux"],
|
240 |
+
nclass=self.num_classes,
|
241 |
+
se_weight=kwargs["se_weight"],
|
242 |
+
aux_weight=kwargs["aux_weight"],
|
243 |
+
ignore_index=kwargs["ignore_index"],
|
244 |
+
)
|
245 |
+
|
246 |
+
@staticmethod
|
247 |
+
def add_model_specific_args(parent_parser):
|
248 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
249 |
+
parser.add_argument(
|
250 |
+
"--data_path", type=str, help="path where dataset is stored"
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--dataset",
|
254 |
+
choices=get_available_datasets(),
|
255 |
+
default="ade20k",
|
256 |
+
help="dataset to train on",
|
257 |
+
)
|
258 |
+
parser.add_argument(
|
259 |
+
"--batch_size", type=int, default=16, help="size of the batches"
|
260 |
+
)
|
261 |
+
parser.add_argument(
|
262 |
+
"--base_lr", type=float, default=0.004, help="learning rate"
|
263 |
+
)
|
264 |
+
parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
|
265 |
+
parser.add_argument(
|
266 |
+
"--weight_decay", type=float, default=1e-4, help="weight_decay"
|
267 |
+
)
|
268 |
+
parser.add_argument(
|
269 |
+
"--aux", action="store_true", default=False, help="Auxilary Loss"
|
270 |
+
)
|
271 |
+
parser.add_argument(
|
272 |
+
"--aux-weight",
|
273 |
+
type=float,
|
274 |
+
default=0.2,
|
275 |
+
help="Auxilary loss weight (default: 0.2)",
|
276 |
+
)
|
277 |
+
parser.add_argument(
|
278 |
+
"--se-loss",
|
279 |
+
action="store_true",
|
280 |
+
default=False,
|
281 |
+
help="Semantic Encoding Loss SE-loss",
|
282 |
+
)
|
283 |
+
parser.add_argument(
|
284 |
+
"--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
|
285 |
+
)
|
286 |
+
|
287 |
+
parser.add_argument(
|
288 |
+
"--midasproto", action="store_true", default=False, help="midasprotocol"
|
289 |
+
)
|
290 |
+
|
291 |
+
parser.add_argument(
|
292 |
+
"--ignore_index",
|
293 |
+
type=int,
|
294 |
+
default=-1,
|
295 |
+
help="numeric value of ignore label in gt",
|
296 |
+
)
|
297 |
+
parser.add_argument(
|
298 |
+
"--augment",
|
299 |
+
action="store_true",
|
300 |
+
default=False,
|
301 |
+
help="Use extended augmentations",
|
302 |
+
)
|
303 |
+
|
304 |
+
return parser
|
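As a hedged usage sketch (not in the diff): a training entry point would typically chain `add_model_specific_args` onto its own parser before handing the arguments to the module. The class name `LSegmentationModule`, the import path, and the `--exp_name` trainer-level flag are assumptions based on the file name, not confirmed by this commit.

```python
from argparse import ArgumentParser

from modules.lsegmentation_module import LSegmentationModule  # assumed class/module name

parent_parser = ArgumentParser(add_help=False)
parent_parser.add_argument("--exp_name", type=str, default="lseg_ade20k")  # hypothetical trainer-level flag

parser = LSegmentationModule.add_model_specific_args(parent_parser)
args = parser.parse_args(["--data_path", "../datasets", "--batch_size", "4", "--aux"])
print(args.base_lr, args.batch_size, args.aux)  # -> 0.004 4 True
```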
modules/models/lseg_blocks.py
ADDED
@@ -0,0 +1,359 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .lseg_vit import (
|
5 |
+
_make_pretrained_clip_vitl16_384,
|
6 |
+
_make_pretrained_clip_vitb32_384,
|
7 |
+
_make_pretrained_clipRN50x16_vitl16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
def _make_encoder(
|
13 |
+
backbone,
|
14 |
+
features,
|
15 |
+
use_pretrained=True,
|
16 |
+
groups=1,
|
17 |
+
expand=False,
|
18 |
+
exportable=True,
|
19 |
+
hooks=None,
|
20 |
+
use_vit_only=False,
|
21 |
+
use_readout="ignore",
|
22 |
+
enable_attention_hooks=False,
|
23 |
+
):
|
24 |
+
if backbone == "clip_vitl16_384":
|
25 |
+
clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384(
|
26 |
+
use_pretrained,
|
27 |
+
hooks=hooks,
|
28 |
+
use_readout=use_readout,
|
29 |
+
enable_attention_hooks=enable_attention_hooks,
|
30 |
+
)
|
31 |
+
scratch = _make_scratch(
|
32 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
33 |
+
)
|
34 |
+
elif backbone == "clipRN50x16_vitl16_384":
|
35 |
+
clip_pretrained, pretrained = _make_pretrained_clipRN50x16_vitl16_384(
|
36 |
+
use_pretrained,
|
37 |
+
hooks=hooks,
|
38 |
+
use_readout=use_readout,
|
39 |
+
enable_attention_hooks=enable_attention_hooks,
|
40 |
+
)
|
41 |
+
scratch = _make_scratch(
|
42 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
43 |
+
)
|
44 |
+
elif backbone == "clip_vitb32_384":
|
45 |
+
clip_pretrained, pretrained = _make_pretrained_clip_vitb32_384(
|
46 |
+
use_pretrained,
|
47 |
+
hooks=hooks,
|
48 |
+
use_readout=use_readout,
|
49 |
+
)
|
50 |
+
scratch = _make_scratch(
|
51 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
print(f"Backbone '{backbone}' not implemented")
|
55 |
+
assert False
|
56 |
+
|
57 |
+
return clip_pretrained, pretrained, scratch
|
58 |
+
|
59 |
+
|
60 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
61 |
+
scratch = nn.Module()
|
62 |
+
|
63 |
+
out_shape1 = out_shape
|
64 |
+
out_shape2 = out_shape
|
65 |
+
out_shape3 = out_shape
|
66 |
+
out_shape4 = out_shape
|
67 |
+
if expand == True:
|
68 |
+
out_shape1 = out_shape
|
69 |
+
out_shape2 = out_shape * 2
|
70 |
+
out_shape3 = out_shape * 4
|
71 |
+
out_shape4 = out_shape * 8
|
72 |
+
|
73 |
+
scratch.layer1_rn = nn.Conv2d(
|
74 |
+
in_shape[0],
|
75 |
+
out_shape1,
|
76 |
+
kernel_size=3,
|
77 |
+
stride=1,
|
78 |
+
padding=1,
|
79 |
+
bias=False,
|
80 |
+
groups=groups,
|
81 |
+
)
|
82 |
+
scratch.layer2_rn = nn.Conv2d(
|
83 |
+
in_shape[1],
|
84 |
+
out_shape2,
|
85 |
+
kernel_size=3,
|
86 |
+
stride=1,
|
87 |
+
padding=1,
|
88 |
+
bias=False,
|
89 |
+
groups=groups,
|
90 |
+
)
|
91 |
+
scratch.layer3_rn = nn.Conv2d(
|
92 |
+
in_shape[2],
|
93 |
+
out_shape3,
|
94 |
+
kernel_size=3,
|
95 |
+
stride=1,
|
96 |
+
padding=1,
|
97 |
+
bias=False,
|
98 |
+
groups=groups,
|
99 |
+
)
|
100 |
+
scratch.layer4_rn = nn.Conv2d(
|
101 |
+
in_shape[3],
|
102 |
+
out_shape4,
|
103 |
+
kernel_size=3,
|
104 |
+
stride=1,
|
105 |
+
padding=1,
|
106 |
+
bias=False,
|
107 |
+
groups=groups,
|
108 |
+
)
|
109 |
+
|
110 |
+
return scratch
|
111 |
+
|
112 |
+
|
113 |
+
class Interpolate(nn.Module):
|
114 |
+
"""Interpolation module."""
|
115 |
+
|
116 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
117 |
+
"""Init.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
scale_factor (float): scaling
|
121 |
+
mode (str): interpolation mode
|
122 |
+
"""
|
123 |
+
super(Interpolate, self).__init__()
|
124 |
+
|
125 |
+
self.interp = nn.functional.interpolate
|
126 |
+
self.scale_factor = scale_factor
|
127 |
+
self.mode = mode
|
128 |
+
self.align_corners = align_corners
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
"""Forward pass.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
x (tensor): input
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
tensor: interpolated data
|
138 |
+
"""
|
139 |
+
|
140 |
+
x = self.interp(
|
141 |
+
x,
|
142 |
+
scale_factor=self.scale_factor,
|
143 |
+
mode=self.mode,
|
144 |
+
align_corners=self.align_corners,
|
145 |
+
)
|
146 |
+
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
class ResidualConvUnit(nn.Module):
|
151 |
+
"""Residual convolution module."""
|
152 |
+
|
153 |
+
def __init__(self, features):
|
154 |
+
"""Init.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
features (int): number of features
|
158 |
+
"""
|
159 |
+
super().__init__()
|
160 |
+
|
161 |
+
self.conv1 = nn.Conv2d(
|
162 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
163 |
+
)
|
164 |
+
|
165 |
+
self.conv2 = nn.Conv2d(
|
166 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
167 |
+
)
|
168 |
+
|
169 |
+
self.relu = nn.ReLU(inplace=True)
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
"""Forward pass.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
x (tensor): input
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
tensor: output
|
179 |
+
"""
|
180 |
+
out = self.relu(x)
|
181 |
+
out = self.conv1(out)
|
182 |
+
out = self.relu(out)
|
183 |
+
out = self.conv2(out)
|
184 |
+
|
185 |
+
return out + x
|
186 |
+
|
187 |
+
|
188 |
+
class FeatureFusionBlock(nn.Module):
|
189 |
+
"""Feature fusion block."""
|
190 |
+
|
191 |
+
def __init__(self, features):
|
192 |
+
"""Init.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
features (int): number of features
|
196 |
+
"""
|
197 |
+
super(FeatureFusionBlock, self).__init__()
|
198 |
+
|
199 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
200 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
201 |
+
|
202 |
+
def forward(self, *xs):
|
203 |
+
"""Forward pass.
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
tensor: output
|
207 |
+
"""
|
208 |
+
output = xs[0]
|
209 |
+
|
210 |
+
if len(xs) == 2:
|
211 |
+
output += self.resConfUnit1(xs[1])
|
212 |
+
|
213 |
+
output = self.resConfUnit2(output)
|
214 |
+
|
215 |
+
output = nn.functional.interpolate(
|
216 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
217 |
+
)
|
218 |
+
|
219 |
+
return output
|
220 |
+
|
221 |
+
|
222 |
+
class ResidualConvUnit_custom(nn.Module):
|
223 |
+
"""Residual convolution module."""
|
224 |
+
|
225 |
+
def __init__(self, features, activation, bn):
|
226 |
+
"""Init.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
features (int): number of features
|
230 |
+
"""
|
231 |
+
super().__init__()
|
232 |
+
|
233 |
+
self.bn = bn
|
234 |
+
|
235 |
+
self.groups = 1
|
236 |
+
|
237 |
+
self.conv1 = nn.Conv2d(
|
238 |
+
features,
|
239 |
+
features,
|
240 |
+
kernel_size=3,
|
241 |
+
stride=1,
|
242 |
+
padding=1,
|
243 |
+
bias=not self.bn,
|
244 |
+
groups=self.groups,
|
245 |
+
)
|
246 |
+
|
247 |
+
self.conv2 = nn.Conv2d(
|
248 |
+
features,
|
249 |
+
features,
|
250 |
+
kernel_size=3,
|
251 |
+
stride=1,
|
252 |
+
padding=1,
|
253 |
+
bias=not self.bn,
|
254 |
+
groups=self.groups,
|
255 |
+
)
|
256 |
+
|
257 |
+
if self.bn == True:
|
258 |
+
self.bn1 = nn.BatchNorm2d(features)
|
259 |
+
self.bn2 = nn.BatchNorm2d(features)
|
260 |
+
|
261 |
+
self.activation = activation
|
262 |
+
|
263 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
264 |
+
|
265 |
+
def forward(self, x):
|
266 |
+
"""Forward pass.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
x (tensor): input
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
tensor: output
|
273 |
+
"""
|
274 |
+
|
275 |
+
out = self.activation(x)
|
276 |
+
out = self.conv1(out)
|
277 |
+
if self.bn == True:
|
278 |
+
out = self.bn1(out)
|
279 |
+
|
280 |
+
out = self.activation(out)
|
281 |
+
out = self.conv2(out)
|
282 |
+
if self.bn == True:
|
283 |
+
out = self.bn2(out)
|
284 |
+
|
285 |
+
if self.groups > 1:
|
286 |
+
out = self.conv_merge(out)
|
287 |
+
|
288 |
+
return self.skip_add.add(out, x)
|
289 |
+
|
290 |
+
# return out + x
|
291 |
+
|
292 |
+
|
293 |
+
class FeatureFusionBlock_custom(nn.Module):
|
294 |
+
"""Feature fusion block."""
|
295 |
+
|
296 |
+
def __init__(
|
297 |
+
self,
|
298 |
+
features,
|
299 |
+
activation,
|
300 |
+
deconv=False,
|
301 |
+
bn=False,
|
302 |
+
expand=False,
|
303 |
+
align_corners=True,
|
304 |
+
):
|
305 |
+
"""Init.
|
306 |
+
|
307 |
+
Args:
|
308 |
+
features (int): number of features
|
309 |
+
"""
|
310 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
311 |
+
|
312 |
+
self.deconv = deconv
|
313 |
+
self.align_corners = align_corners
|
314 |
+
|
315 |
+
self.groups = 1
|
316 |
+
|
317 |
+
self.expand = expand
|
318 |
+
out_features = features
|
319 |
+
if self.expand == True:
|
320 |
+
out_features = features // 2
|
321 |
+
|
322 |
+
self.out_conv = nn.Conv2d(
|
323 |
+
features,
|
324 |
+
out_features,
|
325 |
+
kernel_size=1,
|
326 |
+
stride=1,
|
327 |
+
padding=0,
|
328 |
+
bias=True,
|
329 |
+
groups=1,
|
330 |
+
)
|
331 |
+
|
332 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
333 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
334 |
+
|
335 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
336 |
+
|
337 |
+
def forward(self, *xs):
|
338 |
+
"""Forward pass.
|
339 |
+
|
340 |
+
Returns:
|
341 |
+
tensor: output
|
342 |
+
"""
|
343 |
+
output = xs[0]
|
344 |
+
|
345 |
+
if len(xs) == 2:
|
346 |
+
res = self.resConfUnit1(xs[1])
|
347 |
+
output = self.skip_add.add(output, res)
|
348 |
+
# output += res
|
349 |
+
|
350 |
+
output = self.resConfUnit2(output)
|
351 |
+
|
352 |
+
output = nn.functional.interpolate(
|
353 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
354 |
+
)
|
355 |
+
|
356 |
+
output = self.out_conv(output)
|
357 |
+
|
358 |
+
return output
|
359 |
+
|
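An illustrative shape check for the blocks defined above (not part of the diff): each `FeatureFusionBlock_custom` optionally adds a skip path and upsamples by 2x, and `_make_scratch` projects every hooked feature map to a common channel width. The import path assumes the repo layout shown in this commit.

```python
import torch
import torch.nn as nn

from modules.models.lseg_blocks import FeatureFusionBlock_custom, _make_scratch

fuse = FeatureFusionBlock_custom(256, nn.ReLU(False), bn=False, align_corners=True)
deep = torch.randn(1, 256, 15, 20)    # coarser feature map
skip = torch.randn(1, 256, 30, 40)    # finer feature map from an earlier hook
print(fuse(deep).shape)               # torch.Size([1, 256, 30, 40])  -- 2x upsampling
print(fuse(fuse(deep), skip).shape)   # torch.Size([1, 256, 60, 80])  -- fuse skip, then 2x again

scratch = _make_scratch([256, 512, 1024, 1024], 256)
print(scratch.layer3_rn(torch.randn(1, 1024, 24, 24)).shape)  # torch.Size([1, 256, 24, 24])
```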
modules/models/lseg_net.py
ADDED
@@ -0,0 +1,231 @@
1 |
+
import math
|
2 |
+
import types
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from .lseg_blocks import FeatureFusionBlock, Interpolate, _make_encoder, FeatureFusionBlock_custom, forward_vit
|
9 |
+
import clip
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
import os
|
13 |
+
|
14 |
+
class depthwise_clipseg_conv(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super(depthwise_clipseg_conv, self).__init__()
|
17 |
+
self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1)
|
18 |
+
|
19 |
+
def depthwise_clipseg(self, x, channels):
|
20 |
+
x = torch.cat([self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], dim=1)
|
21 |
+
return x
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
channels = x.shape[1]
|
25 |
+
out = self.depthwise_clipseg(x, channels)
|
26 |
+
return out
|
27 |
+
|
28 |
+
|
29 |
+
class depthwise_conv(nn.Module):
|
30 |
+
def __init__(self, kernel_size=3, stride=1, padding=1):
|
31 |
+
super(depthwise_conv, self).__init__()
|
32 |
+
self.depthwise = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
# support for 4D tensor with NCHW
|
36 |
+
C, H, W = x.shape[1:]
|
37 |
+
x = x.reshape(-1, 1, H, W)
|
38 |
+
x = self.depthwise(x)
|
39 |
+
x = x.view(-1, C, H, W)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class depthwise_block(nn.Module):
|
44 |
+
def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
|
45 |
+
super(depthwise_block, self).__init__()
|
46 |
+
self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1)
|
47 |
+
if activation == 'relu':
|
48 |
+
self.activation = nn.ReLU()
|
49 |
+
elif activation == 'lrelu':
|
50 |
+
self.activation = nn.LeakyReLU()
|
51 |
+
elif activation == 'tanh':
|
52 |
+
self.activation = nn.Tanh()
|
53 |
+
|
54 |
+
def forward(self, x, act=True):
|
55 |
+
x = self.depthwise(x)
|
56 |
+
if act:
|
57 |
+
x = self.activation(x)
|
58 |
+
return x
|
59 |
+
|
60 |
+
|
61 |
+
class bottleneck_block(nn.Module):
|
62 |
+
def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
|
63 |
+
super(bottleneck_block, self).__init__()
|
64 |
+
self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1)
|
65 |
+
if activation == 'relu':
|
66 |
+
self.activation = nn.ReLU()
|
67 |
+
elif activation == 'lrelu':
|
68 |
+
self.activation = nn.LeakyReLU()
|
69 |
+
elif activation == 'tanh':
|
70 |
+
self.activation = nn.Tanh()
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x, act=True):
|
74 |
+
sum_layer = x.max(dim=1, keepdim=True)[0]
|
75 |
+
x = self.depthwise(x)
|
76 |
+
x = x + sum_layer
|
77 |
+
if act:
|
78 |
+
x = self.activation(x)
|
79 |
+
return x
|
80 |
+
|
81 |
+
class BaseModel(torch.nn.Module):
|
82 |
+
def load(self, path):
|
83 |
+
"""Load model from file.
|
84 |
+
Args:
|
85 |
+
path (str): file path
|
86 |
+
"""
|
87 |
+
parameters = torch.load(path, map_location=torch.device("cpu"))
|
88 |
+
|
89 |
+
if "optimizer" in parameters:
|
90 |
+
parameters = parameters["model"]
|
91 |
+
|
92 |
+
self.load_state_dict(parameters)
|
93 |
+
|
94 |
+
def _make_fusion_block(features, use_bn):
|
95 |
+
return FeatureFusionBlock_custom(
|
96 |
+
features,
|
97 |
+
activation=nn.ReLU(False),
|
98 |
+
deconv=False,
|
99 |
+
bn=use_bn,
|
100 |
+
expand=False,
|
101 |
+
align_corners=True,
|
102 |
+
)
|
103 |
+
|
104 |
+
class LSeg(BaseModel):
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
head,
|
108 |
+
features=256,
|
109 |
+
backbone="clip_vitl16_384",
|
110 |
+
readout="project",
|
111 |
+
channels_last=False,
|
112 |
+
use_bn=False,
|
113 |
+
**kwargs,
|
114 |
+
):
|
115 |
+
super(LSeg, self).__init__()
|
116 |
+
|
117 |
+
self.channels_last = channels_last
|
118 |
+
|
119 |
+
hooks = {
|
120 |
+
"clip_vitl16_384": [5, 11, 17, 23],
|
121 |
+
"clipRN50x16_vitl16_384": [5, 11, 17, 23],
|
122 |
+
"clip_vitb32_384": [2, 5, 8, 11],
|
123 |
+
}
|
124 |
+
|
125 |
+
# Instantiate backbone and reassemble blocks
|
126 |
+
self.clip_pretrained, self.pretrained, self.scratch = _make_encoder(
|
127 |
+
backbone,
|
128 |
+
features,
|
129 |
+
groups=1,
|
130 |
+
expand=False,
|
131 |
+
exportable=False,
|
132 |
+
hooks=hooks[backbone],
|
133 |
+
use_readout=readout,
|
134 |
+
)
|
135 |
+
|
136 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
137 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
138 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
139 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
140 |
+
|
141 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).exp()
|
142 |
+
if backbone in ["clipRN50x16_vitl16_384"]:
|
143 |
+
self.out_c = 768
|
144 |
+
else:
|
145 |
+
self.out_c = 512
|
146 |
+
self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1)
|
147 |
+
|
148 |
+
self.arch_option = kwargs["arch_option"]
|
149 |
+
if self.arch_option == 1:
|
150 |
+
self.scratch.head_block = bottleneck_block(activation=kwargs["activation"])
|
151 |
+
self.block_depth = kwargs['block_depth']
|
152 |
+
elif self.arch_option == 2:
|
153 |
+
self.scratch.head_block = depthwise_block(activation=kwargs["activation"])
|
154 |
+
self.block_depth = kwargs['block_depth']
|
155 |
+
|
156 |
+
self.scratch.output_conv = head
|
157 |
+
|
158 |
+
self.text = clip.tokenize(self.labels)
|
159 |
+
|
160 |
+
def forward(self, x, labelset=''):
|
161 |
+
if labelset == '':
|
162 |
+
text = self.text
|
163 |
+
else:
|
164 |
+
text = clip.tokenize(labelset)
|
165 |
+
|
166 |
+
if self.channels_last == True:
|
167 |
+
x.contiguous(memory_format=torch.channels_last)
|
168 |
+
|
169 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
170 |
+
|
171 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
172 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
173 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
174 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
175 |
+
|
176 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
177 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
178 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
179 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
180 |
+
|
181 |
+
text = text.to(x.device)
|
182 |
+
self.logit_scale = self.logit_scale.to(x.device)
|
183 |
+
text_features = self.clip_pretrained.encode_text(text)
|
184 |
+
|
185 |
+
image_features = self.scratch.head1(path_1)
|
186 |
+
|
187 |
+
imshape = image_features.shape
|
188 |
+
image_features = image_features.permute(0,2,3,1).reshape(-1, self.out_c)
|
189 |
+
|
190 |
+
# normalized features
|
191 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
192 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
193 |
+
|
194 |
+
logits_per_image = self.logit_scale * image_features.half() @ text_features.t()
|
195 |
+
|
196 |
+
out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], -1).permute(0,3,1,2)
|
197 |
+
|
198 |
+
if self.arch_option in [1, 2]:
|
199 |
+
for _ in range(self.block_depth - 1):
|
200 |
+
out = self.scratch.head_block(out)
|
201 |
+
out = self.scratch.head_block(out, False)
|
202 |
+
|
203 |
+
out = self.scratch.output_conv(out)
|
204 |
+
|
205 |
+
return out
|
206 |
+
|
207 |
+
|
208 |
+
class LSegNet(LSeg):
|
209 |
+
"""Network for semantic segmentation."""
|
210 |
+
def __init__(self, labels, path=None, scale_factor=0.5, crop_size=480, **kwargs):
|
211 |
+
|
212 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
213 |
+
kwargs["use_bn"] = True
|
214 |
+
|
215 |
+
self.crop_size = crop_size
|
216 |
+
self.scale_factor = scale_factor
|
217 |
+
self.labels = labels
|
218 |
+
|
219 |
+
head = nn.Sequential(
|
220 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
221 |
+
)
|
222 |
+
|
223 |
+
super().__init__(head, **kwargs)
|
224 |
+
|
225 |
+
if path is not None:
|
226 |
+
self.load(path)
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
|
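The heart of `LSeg.forward` above is a per-pixel, per-label dot product between L2-normalized dense image features and CLIP text embeddings. The following self-contained sketch reproduces just that step with random tensors (illustrative only; shapes and the reshape/permute order follow the code above, and the fixed `logit_scale` value stands in for the learned temperature).

```python
import torch

B, C, H, W, num_labels = 1, 512, 60, 80, 4        # C matches self.out_c for the non-RN50x16 backbones
image_features = torch.randn(B, C, H, W)           # stand-in for the output of scratch.head1
text_features = torch.randn(num_labels, C)         # stand-in for clip_pretrained.encode_text(...)
logit_scale = torch.tensor(1 / 0.07)               # the module initializes this as exp(log(1/0.07))

pix = image_features.permute(0, 2, 3, 1).reshape(-1, C)
pix = pix / pix.norm(dim=-1, keepdim=True)         # normalize per-pixel embeddings
txt = text_features / text_features.norm(dim=-1, keepdim=True)  # normalize per-label embeddings

logits = logit_scale * pix @ txt.t()               # (B*H*W, num_labels) similarities
out = logits.view(B, H, W, num_labels).permute(0, 3, 1, 2)
print(out.shape)                                   # torch.Size([1, 4, 60, 80])
```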
modules/models/lseg_vit.py
ADDED
@@ -0,0 +1,535 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import clip
|
8 |
+
|
9 |
+
activations = {}
|
10 |
+
|
11 |
+
|
12 |
+
def get_activation(name):
|
13 |
+
def hook(model, input, output):
|
14 |
+
activations[name] = output
|
15 |
+
|
16 |
+
return hook
|
17 |
+
|
18 |
+
|
19 |
+
attention = {}
|
20 |
+
|
21 |
+
|
22 |
+
def get_attention(name):
|
23 |
+
def hook(module, input, output):
|
24 |
+
x = input[0]
|
25 |
+
B, N, C = x.shape
|
26 |
+
qkv = (
|
27 |
+
module.qkv(x)
|
28 |
+
.reshape(B, N, 3, module.num_heads, C // module.num_heads)
|
29 |
+
.permute(2, 0, 3, 1, 4)
|
30 |
+
)
|
31 |
+
q, k, v = (
|
32 |
+
qkv[0],
|
33 |
+
qkv[1],
|
34 |
+
qkv[2],
|
35 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
36 |
+
|
37 |
+
attn = (q @ k.transpose(-2, -1)) * module.scale
|
38 |
+
|
39 |
+
attn = attn.softmax(dim=-1) # [:,:,1,1:]
|
40 |
+
attention[name] = attn
|
41 |
+
|
42 |
+
return hook
|
43 |
+
|
44 |
+
|
45 |
+
def get_mean_attention_map(attn, token, shape):
|
46 |
+
attn = attn[:, :, token, 1:]
|
47 |
+
attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
|
48 |
+
attn = torch.nn.functional.interpolate(
|
49 |
+
attn, size=shape[2:], mode="bicubic", align_corners=False
|
50 |
+
).squeeze(0)
|
51 |
+
|
52 |
+
all_attn = torch.mean(attn, 0)
|
53 |
+
|
54 |
+
return all_attn
|
55 |
+
|
56 |
+
|
57 |
+
class Slice(nn.Module):
|
58 |
+
def __init__(self, start_index=1):
|
59 |
+
super(Slice, self).__init__()
|
60 |
+
self.start_index = start_index
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return x[:, self.start_index :]
|
64 |
+
|
65 |
+
|
66 |
+
class AddReadout(nn.Module):
|
67 |
+
def __init__(self, start_index=1):
|
68 |
+
super(AddReadout, self).__init__()
|
69 |
+
self.start_index = start_index
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
if self.start_index == 2:
|
73 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
74 |
+
else:
|
75 |
+
readout = x[:, 0]
|
76 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
77 |
+
|
78 |
+
|
79 |
+
class ProjectReadout(nn.Module):
|
80 |
+
def __init__(self, in_features, start_index=1):
|
81 |
+
super(ProjectReadout, self).__init__()
|
82 |
+
self.start_index = start_index
|
83 |
+
|
84 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
88 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
89 |
+
|
90 |
+
return self.project(features)
|
91 |
+
|
92 |
+
|
93 |
+
class Transpose(nn.Module):
|
94 |
+
def __init__(self, dim0, dim1):
|
95 |
+
super(Transpose, self).__init__()
|
96 |
+
self.dim0 = dim0
|
97 |
+
self.dim1 = dim1
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
x = x.transpose(self.dim0, self.dim1)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
def forward_vit(pretrained, x):
|
105 |
+
b, c, h, w = x.shape
|
106 |
+
|
107 |
+
# encoder
|
108 |
+
glob = pretrained.model.forward_flex(x)
|
109 |
+
|
110 |
+
layer_1 = pretrained.activations["1"]
|
111 |
+
layer_2 = pretrained.activations["2"]
|
112 |
+
layer_3 = pretrained.activations["3"]
|
113 |
+
layer_4 = pretrained.activations["4"]
|
114 |
+
|
115 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
116 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
117 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
118 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
119 |
+
|
120 |
+
unflatten = nn.Sequential(
|
121 |
+
nn.Unflatten(
|
122 |
+
2,
|
123 |
+
torch.Size(
|
124 |
+
[
|
125 |
+
h // pretrained.model.patch_size[1],
|
126 |
+
w // pretrained.model.patch_size[0],
|
127 |
+
]
|
128 |
+
),
|
129 |
+
)
|
130 |
+
)
|
131 |
+
|
132 |
+
if layer_1.ndim == 3:
|
133 |
+
layer_1 = unflatten(layer_1)
|
134 |
+
if layer_2.ndim == 3:
|
135 |
+
layer_2 = unflatten(layer_2)
|
136 |
+
if layer_3.ndim == 3:
|
137 |
+
layer_3 = unflatten(layer_3)
|
138 |
+
if layer_4.ndim == 3:
|
139 |
+
layer_4 = unflatten(layer_4)
|
140 |
+
|
141 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
142 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
143 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
144 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
145 |
+
|
146 |
+
return layer_1, layer_2, layer_3, layer_4
|
147 |
+
|
148 |
+
|
149 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
150 |
+
posemb_tok, posemb_grid = (
|
151 |
+
posemb[:, : self.start_index],
|
152 |
+
posemb[0, self.start_index :],
|
153 |
+
)
|
154 |
+
|
155 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
156 |
+
|
157 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
158 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
159 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
160 |
+
|
161 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
162 |
+
|
163 |
+
return posemb
|
164 |
+
|
165 |
+
|
166 |
+
def forward_flex(self, x):
|
167 |
+
b, c, h, w = x.shape
|
168 |
+
|
169 |
+
pos_embed = self._resize_pos_embed(
|
170 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
171 |
+
)
|
172 |
+
|
173 |
+
B = x.shape[0]
|
174 |
+
|
175 |
+
if hasattr(self.patch_embed, "backbone"):
|
176 |
+
x = self.patch_embed.backbone(x)
|
177 |
+
if isinstance(x, (list, tuple)):
|
178 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
179 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
180 |
+
|
181 |
+
if getattr(self, "dist_token", None) is not None:
|
182 |
+
cls_tokens = self.cls_token.expand(
|
183 |
+
B, -1, -1
|
184 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
185 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
186 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
187 |
+
else:
|
188 |
+
cls_tokens = self.cls_token.expand(
|
189 |
+
B, -1, -1
|
190 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
191 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
192 |
+
|
193 |
+
x = x + pos_embed
|
194 |
+
x = self.pos_drop(x)
|
195 |
+
|
196 |
+
for blk in self.blocks:
|
197 |
+
x = blk(x)
|
198 |
+
|
199 |
+
x = self.norm(x)
|
200 |
+
|
201 |
+
return x
|
202 |
+
|
203 |
+
|
204 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
205 |
+
if use_readout == "ignore":
|
206 |
+
readout_oper = [Slice(start_index)] * len(features)
|
207 |
+
elif use_readout == "add":
|
208 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
209 |
+
elif use_readout == "project":
|
210 |
+
readout_oper = [
|
211 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
212 |
+
]
|
213 |
+
else:
|
214 |
+
assert (
|
215 |
+
False
|
216 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
217 |
+
|
218 |
+
return readout_oper
|
219 |
+
|
220 |
+
|
221 |
+
def _make_pretrained_clip_vitl16_384(
|
222 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
223 |
+
):
|
224 |
+
clip_pretrained, _ = clip.load("ViT-B/32", device='cuda', jit=False)
|
225 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
226 |
+
|
227 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
228 |
+
|
229 |
+
pretrained = _make_vit_b16_backbone(
|
230 |
+
model,
|
231 |
+
features=[256, 512, 1024, 1024],
|
232 |
+
hooks=hooks,
|
233 |
+
vit_features=1024,
|
234 |
+
use_readout=use_readout,
|
235 |
+
enable_attention_hooks=enable_attention_hooks,
|
236 |
+
)
|
237 |
+
return clip_pretrained, pretrained
|
238 |
+
|
239 |
+
|
240 |
+
def _make_pretrained_clipRN50x16_vitl16_384(
|
241 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
242 |
+
):
|
243 |
+
clip_pretrained, _ = clip.load("RN50x16", device='cuda', jit=False)
|
244 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
245 |
+
|
246 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
247 |
+
|
248 |
+
pretrained = _make_vit_b16_backbone(
|
249 |
+
model,
|
250 |
+
features=[256, 512, 1024, 1024],
|
251 |
+
hooks=hooks,
|
252 |
+
vit_features=1024,
|
253 |
+
use_readout=use_readout,
|
254 |
+
enable_attention_hooks=enable_attention_hooks,
|
255 |
+
)
|
256 |
+
return clip_pretrained, pretrained
|
257 |
+
|
258 |
+
|
259 |
+
def _make_pretrained_clip_vitb32_384(pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False):
|
260 |
+
clip_pretrained, _ = clip.load("ViT-B/32", device='cuda', jit=False)
|
261 |
+
model = timm.create_model("vit_base_patch32_384", pretrained=pretrained)
|
262 |
+
|
263 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
264 |
+
|
265 |
+
pretrained = _make_vit_b32_backbone(
|
266 |
+
model,
|
267 |
+
features=[96, 192, 384, 768],
|
268 |
+
hooks=hooks,
|
269 |
+
use_readout=use_readout,
|
270 |
+
enable_attention_hooks=False,
|
271 |
+
)
|
272 |
+
return clip_pretrained, pretrained
|
273 |
+
|
274 |
+
|
275 |
+
def _make_vit_b32_backbone(
|
276 |
+
model,
|
277 |
+
features=[96, 192, 384, 768],
|
278 |
+
size=[384, 384],
|
279 |
+
hooks=[2, 5, 8, 11],
|
280 |
+
vit_features=768,
|
281 |
+
use_readout="ignore",
|
282 |
+
start_index=1,
|
283 |
+
enable_attention_hooks=False,
|
284 |
+
):
|
285 |
+
pretrained = nn.Module()
|
286 |
+
|
287 |
+
pretrained.model = model
|
288 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
289 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
290 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
291 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
292 |
+
|
293 |
+
pretrained.activations = activations
|
294 |
+
|
295 |
+
pretrained.model.patch_size = [32, 32]
|
296 |
+
pretrained.model.start_index = start_index
|
297 |
+
|
298 |
+
if enable_attention_hooks:
|
299 |
+
pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
|
300 |
+
get_attention("attn_1")
|
301 |
+
)
|
302 |
+
pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
|
303 |
+
get_attention("attn_2")
|
304 |
+
)
|
305 |
+
pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
|
306 |
+
get_attention("attn_3")
|
307 |
+
)
|
308 |
+
pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
|
309 |
+
get_attention("attn_4")
|
310 |
+
)
|
311 |
+
pretrained.attention = attention
|
312 |
+
|
313 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
314 |
+
|
315 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
316 |
+
readout_oper[0],
|
317 |
+
Transpose(1, 2),
|
318 |
+
nn.Unflatten(2, torch.Size([size[0] // pretrained.model.patch_size[1], size[1] // pretrained.model.patch_size[0]])),
|
319 |
+
nn.Conv2d(
|
320 |
+
in_channels=vit_features,
|
321 |
+
out_channels=features[0],
|
322 |
+
kernel_size=1,
|
323 |
+
stride=1,
|
324 |
+
padding=0,
|
325 |
+
),
|
326 |
+
nn.ConvTranspose2d(
|
327 |
+
in_channels=features[0],
|
328 |
+
out_channels=features[0],
|
329 |
+
kernel_size=8,
|
330 |
+
stride=8,
|
331 |
+
padding=0,
|
332 |
+
bias=True,
|
333 |
+
dilation=1,
|
334 |
+
groups=1,
|
335 |
+
),
|
336 |
+
)
|
337 |
+
|
338 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
339 |
+
readout_oper[1],
|
340 |
+
Transpose(1, 2),
|
341 |
+
nn.Unflatten(2, torch.Size([size[0] // pretrained.model.patch_size[1], size[1] // pretrained.model.patch_size[0]])),
|
342 |
+
nn.Conv2d(
|
343 |
+
in_channels=vit_features,
|
344 |
+
out_channels=features[1],
|
345 |
+
kernel_size=1,
|
346 |
+
stride=1,
|
347 |
+
padding=0,
|
348 |
+
),
|
349 |
+
nn.ConvTranspose2d(
|
350 |
+
in_channels=features[1],
|
351 |
+
out_channels=features[1],
|
352 |
+
kernel_size=4,
|
353 |
+
stride=4,
|
354 |
+
padding=0,
|
355 |
+
bias=True,
|
356 |
+
dilation=1,
|
357 |
+
groups=1,
|
358 |
+
),
|
359 |
+
)
|
360 |
+
|
361 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
362 |
+
readout_oper[2],
|
363 |
+
Transpose(1, 2),
|
364 |
+
nn.Unflatten(2, torch.Size([size[0] // pretrained.model.patch_size[1], size[1] // pretrained.model.patch_size[0]])),
|
365 |
+
nn.Conv2d(
|
366 |
+
in_channels=vit_features,
|
367 |
+
out_channels=features[2],
|
368 |
+
kernel_size=1,
|
369 |
+
stride=1,
|
370 |
+
padding=0,
|
371 |
+
),
|
372 |
+
nn.ConvTranspose2d(
|
373 |
+
in_channels=features[2],
|
374 |
+
out_channels=features[2],
|
375 |
+
kernel_size=2,
|
376 |
+
stride=2,
|
377 |
+
padding=0,
|
378 |
+
# output_padding=output_padding,
|
379 |
+
bias=True,
|
380 |
+
dilation=1,
|
381 |
+
groups=1,
|
382 |
+
),
|
383 |
+
)
|
384 |
+
|
385 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
386 |
+
readout_oper[3],
|
387 |
+
Transpose(1, 2),
|
388 |
+
nn.Unflatten(2, torch.Size([size[0] // pretrained.model.patch_size[1], size[1] // pretrained.model.patch_size[0]])),
|
389 |
+
nn.Conv2d(
|
390 |
+
in_channels=vit_features,
|
391 |
+
out_channels=features[3],
|
392 |
+
kernel_size=1,
|
393 |
+
stride=1,
|
394 |
+
padding=0,
|
395 |
+
),
|
396 |
+
)
|
397 |
+
|
398 |
+
# We inject this function into the VisionTransformer instances so that
|
399 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
400 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
401 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
402 |
+
_resize_pos_embed, pretrained.model
|
403 |
+
)
|
404 |
+
|
405 |
+
return pretrained
|
406 |
+
|
407 |
+
|
408 |
+
def _make_vit_b16_backbone(
|
409 |
+
model,
|
410 |
+
features=[96, 192, 384, 768],
|
411 |
+
size=[384, 384],
|
412 |
+
hooks=[2, 5, 8, 11],
|
413 |
+
vit_features=768,
|
414 |
+
use_readout="ignore",
|
415 |
+
start_index=1,
|
416 |
+
enable_attention_hooks=False,
|
417 |
+
):
|
418 |
+
pretrained = nn.Module()
|
419 |
+
|
420 |
+
pretrained.model = model
|
421 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
422 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
423 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
424 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
425 |
+
|
426 |
+
pretrained.activations = activations
|
427 |
+
|
428 |
+
if enable_attention_hooks:
|
429 |
+
pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
|
430 |
+
get_attention("attn_1")
|
431 |
+
)
|
432 |
+
pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
|
433 |
+
get_attention("attn_2")
|
434 |
+
)
|
435 |
+
pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
|
436 |
+
get_attention("attn_3")
|
437 |
+
)
|
438 |
+
pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
|
439 |
+
get_attention("attn_4")
|
440 |
+
)
|
441 |
+
pretrained.attention = attention
|
442 |
+
|
443 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
444 |
+
|
445 |
+
# 32, 48, 136, 384
|
446 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
447 |
+
readout_oper[0],
|
448 |
+
Transpose(1, 2),
|
449 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
450 |
+
nn.Conv2d(
|
451 |
+
in_channels=vit_features,
|
452 |
+
out_channels=features[0],
|
453 |
+
kernel_size=1,
|
454 |
+
stride=1,
|
455 |
+
padding=0,
|
456 |
+
),
|
457 |
+
nn.ConvTranspose2d(
|
458 |
+
in_channels=features[0],
|
459 |
+
out_channels=features[0],
|
460 |
+
kernel_size=4,
|
461 |
+
stride=4,
|
462 |
+
padding=0,
|
463 |
+
bias=True,
|
464 |
+
dilation=1,
|
465 |
+
groups=1,
|
466 |
+
),
|
467 |
+
)
|
468 |
+
|
469 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
470 |
+
readout_oper[1],
|
471 |
+
Transpose(1, 2),
|
472 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
473 |
+
nn.Conv2d(
|
474 |
+
in_channels=vit_features,
|
475 |
+
out_channels=features[1],
|
476 |
+
kernel_size=1,
|
477 |
+
stride=1,
|
478 |
+
padding=0,
|
479 |
+
),
|
480 |
+
nn.ConvTranspose2d(
|
481 |
+
in_channels=features[1],
|
482 |
+
out_channels=features[1],
|
483 |
+
kernel_size=2,
|
484 |
+
stride=2,
|
485 |
+
padding=0,
|
486 |
+
bias=True,
|
487 |
+
dilation=1,
|
488 |
+
groups=1,
|
489 |
+
),
|
490 |
+
)
|
491 |
+
|
492 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
493 |
+
readout_oper[2],
|
494 |
+
Transpose(1, 2),
|
495 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
496 |
+
nn.Conv2d(
|
497 |
+
in_channels=vit_features,
|
498 |
+
out_channels=features[2],
|
499 |
+
kernel_size=1,
|
500 |
+
stride=1,
|
501 |
+
padding=0,
|
502 |
+
),
|
503 |
+
)
|
504 |
+
|
505 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
506 |
+
readout_oper[3],
|
507 |
+
Transpose(1, 2),
|
508 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
509 |
+
nn.Conv2d(
|
510 |
+
in_channels=vit_features,
|
511 |
+
out_channels=features[3],
|
512 |
+
kernel_size=1,
|
513 |
+
stride=1,
|
514 |
+
padding=0,
|
515 |
+
),
|
516 |
+
nn.Conv2d(
|
517 |
+
in_channels=features[3],
|
518 |
+
out_channels=features[3],
|
519 |
+
kernel_size=3,
|
520 |
+
stride=2,
|
521 |
+
padding=1,
|
522 |
+
),
|
523 |
+
)
|
524 |
+
|
525 |
+
pretrained.model.start_index = start_index
|
526 |
+
pretrained.model.patch_size = [16, 16]
|
527 |
+
|
528 |
+
# We inject this function into the VisionTransformer instances so that
|
529 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
530 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
531 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
532 |
+
_resize_pos_embed, pretrained.model
|
533 |
+
)
|
534 |
+
|
535 |
+
return pretrained
|
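The backbone builders above rely on forward hooks to expose intermediate transformer activations, which `forward_vit` then reassembles into feature maps. Here is a toy standalone sketch of that hook pattern; the four-`Linear` list is a stand-in for timm's ViT blocks, and the tapped indices are analogous to `hooks=[5, 11, 17, 23]`.

```python
import torch
import torch.nn as nn

activations = {}

def get_activation(name):
    def hook(model, input, output):
        activations[name] = output   # stash the block output under a fixed key
    return hook

blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])  # toy "backbone"
hooks = [0, 2]                                                 # indices of the blocks to tap
blocks[hooks[0]].register_forward_hook(get_activation("1"))
blocks[hooks[1]].register_forward_hook(get_activation("2"))

x = torch.randn(2, 16)
for blk in blocks:
    x = blk(x)

print(activations["1"].shape, activations["2"].shape)  # torch.Size([2, 16]) torch.Size([2, 16])
```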
prepare_ade20k.py
ADDED
@@ -0,0 +1,45 @@
1 |
+
# +
|
2 |
+
# revised from https://github.com/zhanghang1989/PyTorch-Encoding/blob/331ecdd5306104614cb414b16fbcd9d1a8d40e1e/scripts/prepare_ade20k.py
|
3 |
+
|
4 |
+
"""Prepare ADE20K dataset"""
|
5 |
+
import os
|
6 |
+
import shutil
|
7 |
+
import argparse
|
8 |
+
import zipfile
|
9 |
+
from encoding.utils import download, mkdir
|
10 |
+
# -
|
11 |
+
|
12 |
+
_TARGET_DIR = os.path.expanduser('../datasets/')
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
parser = argparse.ArgumentParser(
|
16 |
+
description='Initialize ADE20K dataset.',
|
17 |
+
epilog='Example: python prepare_ade20k.py',
|
18 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
19 |
+
parser.add_argument('--download-dir', default=None, help='dataset directory on disk')
|
20 |
+
args = parser.parse_args()
|
21 |
+
return args
|
22 |
+
|
23 |
+
def download_ade(path, overwrite=False):
|
24 |
+
_AUG_DOWNLOAD_URLS = [
|
25 |
+
('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'),
|
26 |
+
('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 'e05747892219d10e9243933371a497e905a4860c'),]
|
27 |
+
download_dir = path
|
28 |
+
mkdir(download_dir)
|
29 |
+
for url, checksum in _AUG_DOWNLOAD_URLS:
|
30 |
+
filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum)
|
31 |
+
# extract
|
32 |
+
with zipfile.ZipFile(filename,"r") as zip_ref:
|
33 |
+
zip_ref.extractall(path=path)
|
34 |
+
|
35 |
+
|
36 |
+
if __name__ == '__main__':
|
37 |
+
args = parse_args()
|
38 |
+
mkdir(os.path.expanduser('../datasets/'))
|
39 |
+
if args.download_dir is not None:
|
40 |
+
if os.path.isdir(_TARGET_DIR):
|
41 |
+
os.remove(_TARGET_DIR)
|
42 |
+
# make symlink
|
43 |
+
os.symlink(args.download_dir, _TARGET_DIR)
|
44 |
+
else:
|
45 |
+
download_ade(_TARGET_DIR, overwrite=False)
|
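This script depends on `encoding.utils.download` and `mkdir`. If the `encoding` package is unavailable, the same fetch-verify-extract step can be done with the standard library alone; the sketch below is a hedged alternative, not the script's own code (the URL and SHA-1 checksum come from the list above, the helper name is hypothetical).

```python
import hashlib
import os
import urllib.request
import zipfile

def fetch_and_extract(url, sha1, dest="../datasets"):
    os.makedirs(dest, exist_ok=True)
    filename = os.path.join(dest, url.split("/")[-1])
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)   # download the archive
    h = hashlib.sha1()
    with open(filename, "rb") as f:                 # verify in 1 MiB chunks
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    assert h.hexdigest() == sha1, f"checksum mismatch for {filename}"
    with zipfile.ZipFile(filename, "r") as zip_ref:
        zip_ref.extractall(path=dest)

if __name__ == "__main__":
    fetch_and_extract(
        "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip",
        "219e1696abb36c8ba3a3afe7fb2f4b4606a897c7",
    )
```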
test_lseg.py
ADDED
@@ -0,0 +1,436 @@
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
from collections import OrderedDict
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.utils import data
|
9 |
+
import torchvision.transforms as transform
|
10 |
+
from torch.nn.parallel.scatter_gather import gather
|
11 |
+
import encoding.utils as utils
|
12 |
+
from encoding.nn import SegmentationLosses, SyncBatchNorm
|
13 |
+
from encoding.parallel import DataParallelModel, DataParallelCriterion
|
14 |
+
from encoding.datasets import test_batchify_fn
|
15 |
+
from encoding.models.sseg import BaseNet
|
16 |
+
from modules.lseg_module import LSegModule
|
17 |
+
from utils import Resize
|
18 |
+
import cv2
|
19 |
+
import math
|
20 |
+
import types
|
21 |
+
import functools
|
22 |
+
import torchvision.transforms as torch_transforms
|
23 |
+
import copy
|
24 |
+
import itertools
|
25 |
+
from PIL import Image
|
26 |
+
import matplotlib.pyplot as plt
|
27 |
+
import clip
|
28 |
+
import matplotlib as mpl
|
29 |
+
import matplotlib.colors as mplc
|
30 |
+
import matplotlib.figure as mplfigure
|
31 |
+
import matplotlib.patches as mpatches
|
32 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
33 |
+
from data import get_dataset
|
34 |
+
from additional_utils.encoding_models import MultiEvalModule as LSeg_MultiEvalModule
|
35 |
+
import torchvision.transforms as transforms
|
36 |
+
|
37 |
+
class Options:
|
38 |
+
def __init__(self):
|
39 |
+
parser = argparse.ArgumentParser(description="PyTorch Segmentation")
|
40 |
+
# model and dataset
|
41 |
+
parser.add_argument(
|
42 |
+
"--model", type=str, default="encnet", help="model name (default: encnet)"
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"--backbone",
|
46 |
+
type=str,
|
47 |
+
default="clip_vitl16_384",
|
48 |
+
help="backbone name (default: resnet50)",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--dataset",
|
52 |
+
type=str,
|
53 |
+
default="ade20k",
|
54 |
+
help="dataset name (default: pascal12)",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--workers", type=int, default=16, metavar="N", help="dataloader threads"
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"--base-size", type=int, default=520, help="base image size"
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--crop-size", type=int, default=480, help="crop image size"
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--train-split",
|
67 |
+
type=str,
|
68 |
+
default="train",
|
69 |
+
help="dataset train split (default: train)",
|
70 |
+
)
|
71 |
+
# training hyper params
|
72 |
+
parser.add_argument(
|
73 |
+
"--aux", action="store_true", default=False, help="Auxilary Loss"
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--se-loss",
|
77 |
+
action="store_true",
|
78 |
+
default=False,
|
79 |
+
help="Semantic Encoding Loss SE-loss",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"--batch-size",
|
86 |
+
type=int,
|
87 |
+
default=16,
|
88 |
+
metavar="N",
|
89 |
+
help="input batch size for \
|
90 |
+
training (default: 16)",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--test-batch-size",
|
94 |
+
type=int,
|
95 |
+
default=16,
|
96 |
+
metavar="N",
|
97 |
+
help="input batch size for \
|
98 |
+
testing (default: same as batch size)",
|
99 |
+
)
|
100 |
+
# cuda, seed and logging
|
101 |
+
parser.add_argument(
|
102 |
+
"--no-cuda",
|
103 |
+
action="store_true",
|
104 |
+
default=False,
|
105 |
+
help="disables CUDA training",
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--weights", type=str, default=None, help="checkpoint to test"
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--eval", action="store_true", default=False, help="evaluating mIoU"
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"--export",
|
118 |
+
type=str,
|
119 |
+
default=None,
|
120 |
+
help="put the path to resuming file if needed",
|
121 |
+
)
|
122 |
+
|
123 |
+
parser.add_argument(
|
124 |
+
"--acc-bn",
|
125 |
+
action="store_true",
|
126 |
+
default=False,
|
127 |
+
help="Re-accumulate BN statistics",
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"--test-val",
|
131 |
+
action="store_true",
|
132 |
+
default=False,
|
133 |
+
help="generate masks on val set",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--no-val",
|
137 |
+
action="store_true",
|
138 |
+
default=False,
|
139 |
+
help="skip validation during training",
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--module",
|
143 |
+
default='lseg',
|
144 |
+
help="select model definition",
|
145 |
+
)
|
146 |
+
# test option
|
147 |
+
parser.add_argument(
|
148 |
+
"--data-path", type=str, default=None, help="path to test image folder"
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--no-scaleinv",
|
152 |
+
dest="scale_inv",
|
153 |
+
default=True,
|
154 |
+
action="store_false",
|
155 |
+
help="turn off scaleinv layers",
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--widehead", default=False, action="store_true", help="wider output head"
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--widehead_hr",
|
162 |
+
default=False,
|
163 |
+
action="store_true",
|
164 |
+
help="wider output head",
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--ignore_index",
|
168 |
+
type=int,
|
169 |
+
default=-1,
|
170 |
+
help="numeric value of ignore label in gt",
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--label_src",
|
174 |
+
type=str,
|
175 |
+
default="default",
|
176 |
+
help="how to get the labels",
|
177 |
+
)
|
178 |
+
parser.add_argument(
|
179 |
+
"--jobname",
|
180 |
+
type=str,
|
181 |
+
default="default",
|
182 |
+
help="select which dataset",
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--no-strict",
|
186 |
+
dest="strict",
|
187 |
+
default=True,
|
188 |
+
action="store_false",
|
189 |
+
help="no-strict copy the model",
|
190 |
+
)
|
191 |
+
parser.add_argument(
|
192 |
+
"--arch_option",
|
193 |
+
type=int,
|
194 |
+
default=0,
|
195 |
+
help="which kind of architecture to be used",
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--block_depth",
|
199 |
+
type=int,
|
200 |
+
default=0,
|
201 |
+
help="how many blocks should be used",
|
202 |
+
)
|
203 |
+
parser.add_argument(
|
204 |
+
"--activation",
|
205 |
+
choices=['lrelu', 'tanh'],
|
206 |
+
default="lrelu",
|
207 |
+
help="use which activation to activate the block",
|
208 |
+
)
|
209 |
+
|
210 |
+
self.parser = parser
|
211 |
+
|
212 |
+
def parse(self):
|
213 |
+
args = self.parser.parse_args()
|
214 |
+
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
215 |
+
print(args)
|
216 |
+
return args
|
217 |
+
|
218 |
+
|
219 |
+
def test(args):
|
220 |
+
|
221 |
+
module = LSegModule.load_from_checkpoint(
|
222 |
+
checkpoint_path=args.weights,
|
223 |
+
data_path=args.data_path,
|
224 |
+
dataset=args.dataset,
|
225 |
+
backbone=args.backbone,
|
226 |
+
aux=args.aux,
|
227 |
+
num_features=256,
|
228 |
+
aux_weight=0,
|
229 |
+
se_loss=False,
|
230 |
+
se_weight=0,
|
231 |
+
base_lr=0,
|
232 |
+
batch_size=1,
|
233 |
+
max_epochs=0,
|
234 |
+
ignore_index=args.ignore_index,
|
235 |
+
dropout=0.0,
|
236 |
+
scale_inv=args.scale_inv,
|
237 |
+
augment=False,
|
238 |
+
no_batchnorm=False,
|
239 |
+
widehead=args.widehead,
|
240 |
+
widehead_hr=args.widehead_hr,
|
241 |
+
map_location="cpu",
|
242 |
+
arch_option=args.arch_option,
|
243 |
+
strict=args.strict,
|
244 |
+
block_depth=args.block_depth,
|
245 |
+
activation=args.activation,
|
246 |
+
)
|
247 |
+
input_transform = module.val_transform
|
248 |
+
num_classes = module.num_classes
|
249 |
+
|
250 |
+
# dataset
|
251 |
+
testset = get_dataset(
|
252 |
+
args.dataset,
|
253 |
+
root=args.data_path,
|
254 |
+
split="val",
|
255 |
+
mode="testval",
|
256 |
+
transform=input_transform,
|
257 |
+
)
|
258 |
+
|
259 |
+
# dataloader
|
260 |
+
loader_kwargs = (
|
261 |
+
{"num_workers": args.workers, "pin_memory": True} if args.cuda else {}
|
262 |
+
)
|
263 |
+
test_data = data.DataLoader(
|
264 |
+
testset,
|
265 |
+
batch_size=args.test_batch_size,
|
266 |
+
drop_last=False,
|
267 |
+
shuffle=False,
|
268 |
+
collate_fn=test_batchify_fn,
|
269 |
+
**loader_kwargs
|
270 |
+
)
|
271 |
+
|
272 |
+
if isinstance(module.net, BaseNet):
|
273 |
+
model = module.net
|
274 |
+
else:
|
275 |
+
model = module
|
276 |
+
|
277 |
+
model = model.eval()
|
278 |
+
model = model.cpu()
|
279 |
+
|
280 |
+
print(model)
|
281 |
+
if args.acc_bn:
|
282 |
+
from encoding.utils.precise_bn import update_bn_stats
|
283 |
+
|
284 |
+
data_kwargs = {
|
285 |
+
"transform": input_transform,
|
286 |
+
"base_size": args.base_size,
|
287 |
+
"crop_size": args.crop_size,
|
288 |
+
}
|
289 |
+
trainset = get_dataset(
|
290 |
+
args.dataset, split=args.train_split, mode="train", **data_kwargs
|
291 |
+
)
|
292 |
+
trainloader = data.DataLoader(
|
293 |
+
ReturnFirstClosure(trainset),
|
295 |
+
batch_size=args.batch_size,
|
296 |
+
drop_last=True,
|
297 |
+
shuffle=True,
|
298 |
+
**loader_kwargs
|
299 |
+
)
|
300 |
+
print("Reseting BN statistics")
|
301 |
+
model.cuda()
|
302 |
+
update_bn_stats(model, trainloader)
|
303 |
+
|
304 |
+
if args.export:
|
305 |
+
torch.save(model.state_dict(), args.export + ".pth")
|
306 |
+
return
|
307 |
+
|
308 |
+
scales = (
|
309 |
+
[0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
|
310 |
+
if args.dataset == "citys"
|
311 |
+
else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
312 |
+
)
|
313 |
+
|
314 |
+
evaluator = LSeg_MultiEvalModule(
|
315 |
+
model, num_classes, scales=scales, flip=True
|
316 |
+
).cuda()
|
317 |
+
evaluator.eval()
|
318 |
+
|
319 |
+
metric = utils.SegmentationMetric(testset.num_class)
|
320 |
+
tbar = tqdm(test_data)
|
321 |
+
|
322 |
+
f = open("logs/log_test_{}_{}.txt".format(args.jobname, args.dataset), "a+")
|
323 |
+
per_class_iou = np.zeros(testset.num_class)
|
324 |
+
cnt = 0
|
325 |
+
for i, (image, dst) in enumerate(tbar):
|
326 |
+
if args.eval:
|
327 |
+
with torch.no_grad():
|
328 |
+
if False:  # set to True to run the inlined multi-scale/flip loop below instead of evaluator.parallel_forward
|
329 |
+
sample = {"image": image[0].cpu().permute(1, 2, 0).numpy()}
|
330 |
+
out = torch.zeros(
|
331 |
+
1, testset.num_class, image[0].shape[1], image[0].shape[2]
|
332 |
+
).cuda()
|
333 |
+
|
334 |
+
H, W = image[0].shape[1], image[0].shape[2]
|
335 |
+
for scale in scales:
|
336 |
+
long_size = int(math.ceil(520 * scale))
|
337 |
+
if H > W:
|
338 |
+
height = long_size
|
339 |
+
width = int(1.0 * W * long_size / H + 0.5)
|
340 |
+
short_size = width
|
341 |
+
else:
|
342 |
+
width = long_size
|
343 |
+
height = int(1.0 * H * long_size / W + 0.5)
|
344 |
+
short_size = height
|
345 |
+
|
346 |
+
rs = Resize(
|
347 |
+
width,
|
348 |
+
height,
|
349 |
+
resize_target=False,
|
350 |
+
keep_aspect_ratio=True,
|
351 |
+
ensure_multiple_of=32,
|
352 |
+
resize_method="minimal",
|
353 |
+
image_interpolation_method=cv2.INTER_AREA,
|
354 |
+
)
|
355 |
+
|
356 |
+
inf_image = (
|
357 |
+
torch.from_numpy(rs(sample)["image"])
|
358 |
+
.cuda()
|
359 |
+
.permute(2, 0, 1)
|
360 |
+
.unsqueeze(0)
|
361 |
+
)
|
362 |
+
inf_image = torch.cat((inf_image, torch.fliplr(inf_image)), 0)
|
363 |
+
try:
|
364 |
+
pred = model(inf_image)
|
365 |
+
except:
|
366 |
+
print(H, W, long_size, i)
|
367 |
+
exit()
|
368 |
+
|
369 |
+
pred0 = F.softmax(pred[0], dim=1)
|
370 |
+
pred1 = F.softmax(pred[1], dim=1)
|
371 |
+
|
372 |
+
pred = pred0 + 0.2 * pred1
|
373 |
+
|
374 |
+
out += F.interpolate(
|
375 |
+
pred.sum(0, keepdim=True),
|
376 |
+
(out.shape[2], out.shape[3]),
|
377 |
+
mode="bilinear",
|
378 |
+
align_corners=True,
|
379 |
+
)
|
380 |
+
|
381 |
+
predicts = [out]
|
382 |
+
else:
|
383 |
+
predicts = evaluator.parallel_forward(image)
|
384 |
+
|
385 |
+
metric.update(dst, predicts)
|
386 |
+
pixAcc, mIoU = metric.get()
|
387 |
+
|
388 |
+
_, _, total_inter, total_union = metric.get_all()
|
389 |
+
per_class_iou += 1.0 * total_inter / (np.spacing(1) + total_union)
|
390 |
+
cnt+=1
|
391 |
+
|
392 |
+
tbar.set_description("pixAcc: %.4f, mIoU: %.4f" % (pixAcc, mIoU))
|
393 |
+
else:
|
394 |
+
with torch.no_grad():
|
395 |
+
outputs = evaluator.parallel_forward(image)
|
396 |
+
predicts = [
|
397 |
+
testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
|
398 |
+
for output in outputs
|
399 |
+
]
|
400 |
+
|
401 |
+
# output folder
|
402 |
+
outdir = "outdir_ours"
|
403 |
+
if not os.path.exists(outdir):
|
404 |
+
os.makedirs(outdir)
|
405 |
+
|
406 |
+
for predict, impath in zip(predicts, dst):
|
407 |
+
mask = utils.get_mask_pallete(predict, args.dataset)
|
408 |
+
outname = os.path.splitext(impath)[0] + ".png"
|
409 |
+
mask.save(os.path.join(outdir, outname))
|
410 |
+
|
411 |
+
if args.eval:
|
412 |
+
each_classes_iou = per_class_iou/cnt
|
413 |
+
print("pixAcc: %.4f, mIoU: %.4f" % (pixAcc, mIoU))
|
414 |
+
print(each_classes_iou)
|
415 |
+
f.write("dataset {} ==> pixAcc: {:.4f}, mIoU: {:.4f}\n".format(args.dataset, pixAcc, mIoU))
|
416 |
+
for per_iou in each_classes_iou: f.write('{:.4f}, '.format(per_iou))
|
417 |
+
f.write('\n')
|
418 |
+
|
419 |
+
|
420 |
+
class ReturnFirstClosure(object):
|
421 |
+
def __init__(self, data):
|
422 |
+
self._data = data
|
423 |
+
|
424 |
+
def __len__(self):
|
425 |
+
return len(self._data)
|
426 |
+
|
427 |
+
def __getitem__(self, idx):
|
428 |
+
outputs = self._data[idx]
|
429 |
+
return outputs[0]
|
430 |
+
|
431 |
+
|
432 |
+
if __name__ == "__main__":
|
433 |
+
args = Options().parse()
|
434 |
+
torch.manual_seed(args.seed)
|
435 |
+
args.test_batch_size = torch.cuda.device_count()
|
436 |
+
test(args)
|
train_lseg.py
ADDED
@@ -0,0 +1,7 @@
1 |
+
from modules.lseg_module import LSegModule
|
2 |
+
from utils import do_training, get_default_argument_parser
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
parser = LSegModule.add_model_specific_args(get_default_argument_parser())
|
6 |
+
args = parser.parse_args()
|
7 |
+
do_training(args, LSegModule)
|
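
train_lseg.py relies on the parser-composition pattern from utils.py: get_default_argument_parser is created with add_help=False so that LSegModule.add_model_specific_args can wrap it as a parent parser and append its own flags. A minimal sketch of that pattern; ToyModule, get_shared_parser and the --backbone flag are illustrative stand-ins, not the real LSegModule arguments:

from argparse import ArgumentParser

def get_shared_parser():
    # shared training flags (mirrors get_default_argument_parser in utils.py)
    parser = ArgumentParser(add_help=False)
    parser.add_argument("--exp_name", type=str, required=True, help="name your experiment")
    return parser

class ToyModule:
    @staticmethod
    def add_model_specific_args(parent_parser):
        # wrap the shared parser and append model-specific flags (hypothetical example flag)
        parser = ArgumentParser(parents=[parent_parser])
        parser.add_argument("--backbone", type=str, default="clip_vitl16_384")
        return parser

if __name__ == "__main__":
    args = ToyModule.add_model_specific_args(get_shared_parser()).parse_args(["--exp_name", "demo"])
    print(args)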
utils.py
ADDED
@@ -0,0 +1,368 @@
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
|
4 |
+
from glob import glob
|
5 |
+
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
import torch
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import numpy as np
|
10 |
+
import cv2
|
11 |
+
import random
|
12 |
+
import math
|
13 |
+
from torchvision import transforms
|
14 |
+
|
15 |
+
|
16 |
+
def do_training(hparams, model_constructor):
|
17 |
+
# instantiate model
|
18 |
+
model = model_constructor(**vars(hparams))
|
19 |
+
# set all sorts of training parameters
|
20 |
+
hparams.gpus = -1
|
21 |
+
hparams.accelerator = "ddp"
|
22 |
+
hparams.benchmark = True
|
23 |
+
|
24 |
+
if hparams.dry_run:
|
25 |
+
print("Doing a dry run")
|
26 |
+
hparams.overfit_batches = hparams.batch_size
|
27 |
+
|
28 |
+
if not hparams.no_resume:
|
29 |
+
hparams = set_resume_parameters(hparams)
|
30 |
+
|
31 |
+
if not hasattr(hparams, "version") or hparams.version is None:
|
32 |
+
hparams.version = 0
|
33 |
+
|
34 |
+
hparams.sync_batchnorm = True
|
35 |
+
|
36 |
+
ttlogger = pl.loggers.TestTubeLogger(
|
37 |
+
"checkpoints", name=hparams.exp_name, version=hparams.version
|
38 |
+
)
|
39 |
+
|
40 |
+
hparams.callbacks = make_checkpoint_callbacks(hparams.exp_name, hparams.version)
|
41 |
+
|
42 |
+
wblogger = get_wandb_logger(hparams)
|
43 |
+
hparams.logger = [wblogger, ttlogger]
|
44 |
+
|
45 |
+
trainer = pl.Trainer.from_argparse_args(hparams)
|
46 |
+
trainer.fit(model)
|
47 |
+
|
48 |
+
|
49 |
+
def get_default_argument_parser():
|
50 |
+
parser = ArgumentParser(add_help=False)
|
51 |
+
parser.add_argument(
|
52 |
+
"--num_nodes",
|
53 |
+
type=int,
|
54 |
+
default=1,
|
55 |
+
help="number of nodes for distributed training",
|
56 |
+
)
|
57 |
+
|
58 |
+
parser.add_argument(
|
59 |
+
"--exp_name", type=str, required=True, help="name your experiment"
|
60 |
+
)
|
61 |
+
|
62 |
+
parser.add_argument(
|
63 |
+
"--dry-run",
|
64 |
+
action="store_true",
|
65 |
+
default=False,
|
66 |
+
help="run on batch of train/val/test",
|
67 |
+
)
|
68 |
+
|
69 |
+
parser.add_argument(
|
70 |
+
"--no_resume",
|
71 |
+
action="store_true",
|
72 |
+
default=False,
|
73 |
+
help="resume if we have a checkpoint",
|
74 |
+
)
|
75 |
+
|
76 |
+
parser.add_argument(
|
77 |
+
"--accumulate_grad_batches",
|
78 |
+
type=int,
|
79 |
+
default=1,
|
80 |
+
help="accumulate N batches for gradient computation",
|
81 |
+
)
|
82 |
+
|
83 |
+
parser.add_argument(
|
84 |
+
"--max_epochs", type=int, default=200, help="maximum number of epochs"
|
85 |
+
)
|
86 |
+
|
87 |
+
parser.add_argument(
|
88 |
+
"--project_name", type=str, default="lightseg", help="project name for logging"
|
89 |
+
)
|
90 |
+
|
91 |
+
return parser
|
92 |
+
|
93 |
+
|
94 |
+
def make_checkpoint_callbacks(exp_name, version, base_path="checkpoints", frequency=1):
|
95 |
+
version = 0 if version is None else version
|
96 |
+
|
97 |
+
base_callback = pl.callbacks.ModelCheckpoint(
|
98 |
+
dirpath=f"{base_path}/{exp_name}/version_{version}/checkpoints/",
|
99 |
+
save_last=True,
|
100 |
+
verbose=True,
|
101 |
+
)
|
102 |
+
|
103 |
+
val_callback = pl.callbacks.ModelCheckpoint(
|
104 |
+
monitor="val_acc_epoch",
|
105 |
+
dirpath=f"{base_path}/{exp_name}/version_{version}/checkpoints/",
|
106 |
+
filename="result-{epoch}-{val_acc_epoch:.2f}",
|
107 |
+
mode="max",
|
108 |
+
save_top_k=3,
|
109 |
+
verbose=True,
|
110 |
+
)
|
111 |
+
|
112 |
+
return [base_callback, val_callback]
|
113 |
+
|
114 |
+
|
115 |
+
def get_latest_version(folder):
|
116 |
+
versions = [
|
117 |
+
int(pathlib.PurePath(path).name.split("_")[-1])
|
118 |
+
for path in glob(f"{folder}/version_*/")
|
119 |
+
]
|
120 |
+
|
121 |
+
if len(versions) == 0:
|
122 |
+
return None
|
123 |
+
|
124 |
+
versions.sort()
|
125 |
+
return versions[-1]
|
126 |
+
|
127 |
+
|
128 |
+
def get_latest_checkpoint(exp_name, version):
|
129 |
+
while version > -1:
|
130 |
+
folder = f"./checkpoints/{exp_name}/version_{version}/checkpoints/"
|
131 |
+
|
132 |
+
latest = f"{folder}/last.ckpt"
|
133 |
+
if os.path.exists(latest):
|
134 |
+
return latest, version
|
135 |
+
|
136 |
+
chkpts = glob(f"{folder}/epoch=*.ckpt")
|
137 |
+
|
138 |
+
if len(chkpts) > 0:
|
139 |
+
break
|
140 |
+
|
141 |
+
version -= 1
|
142 |
+
|
143 |
+
if len(chkpts) == 0:
|
144 |
+
return None, None
|
145 |
+
|
146 |
+
latest = max(chkpts, key=os.path.getctime)
|
147 |
+
|
148 |
+
return latest, version
|
149 |
+
|
150 |
+
|
151 |
+
def set_resume_parameters(hparams):
|
152 |
+
version = get_latest_version(f"./checkpoints/{hparams.exp_name}")
|
153 |
+
|
154 |
+
if version is not None:
|
155 |
+
latest, version = get_latest_checkpoint(hparams.exp_name, version)
|
156 |
+
print(f"Resuming checkpoint {latest}, exp_version={version}")
|
157 |
+
|
158 |
+
hparams.resume_from_checkpoint = latest
|
159 |
+
hparams.version = version
|
160 |
+
|
161 |
+
wandb_file = "checkpoints/{hparams.exp_name}/version_{version}/wandb_id"
|
162 |
+
if os.path.exists(wandb_file):
|
163 |
+
with open(wandb_file, "r") as f:
|
164 |
+
hparams.wandb_id = f.read()
|
165 |
+
else:
|
166 |
+
version = 0
|
167 |
+
|
168 |
+
return hparams
|
169 |
+
|
170 |
+
|
171 |
+
def get_wandb_logger(hparams):
|
172 |
+
exp_dir = f"checkpoints/{hparams.exp_name}/version_{hparams.version}/"
|
173 |
+
id_file = f"{exp_dir}/wandb_id"
|
174 |
+
|
175 |
+
if os.path.exists(id_file):
|
176 |
+
with open(id_file) as f:
|
177 |
+
hparams.wandb_id = f.read()
|
178 |
+
else:
|
179 |
+
hparams.wandb_id = None
|
180 |
+
|
181 |
+
logger = pl.loggers.WandbLogger(
|
182 |
+
save_dir="checkpoints",
|
183 |
+
project=hparams.project_name,
|
184 |
+
name=hparams.exp_name,
|
185 |
+
id=hparams.wandb_id,
|
186 |
+
)
|
187 |
+
|
188 |
+
if hparams.wandb_id is None:
|
189 |
+
_ = logger.experiment
|
190 |
+
|
191 |
+
if not os.path.exists(exp_dir):
|
192 |
+
os.makedirs(exp_dir)
|
193 |
+
|
194 |
+
with open(id_file, "w") as f:
|
195 |
+
f.write(logger.version)
|
196 |
+
|
197 |
+
return logger
|
198 |
+
|
199 |
+
|
200 |
+
class Resize(object):
|
201 |
+
"""Resize sample to given size (width, height)."""
|
202 |
+
|
203 |
+
def __init__(
|
204 |
+
self,
|
205 |
+
width,
|
206 |
+
height,
|
207 |
+
resize_target=True,
|
208 |
+
keep_aspect_ratio=False,
|
209 |
+
ensure_multiple_of=1,
|
210 |
+
resize_method="lower_bound",
|
211 |
+
image_interpolation_method=cv2.INTER_AREA,
|
212 |
+
letter_box=False,
|
213 |
+
):
|
214 |
+
"""Init.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
width (int): desired output width
|
218 |
+
height (int): desired output height
|
219 |
+
resize_target (bool, optional):
|
220 |
+
True: Resize the full sample (image, mask, target).
|
221 |
+
False: Resize image only.
|
222 |
+
Defaults to True.
|
223 |
+
keep_aspect_ratio (bool, optional):
|
224 |
+
True: Keep the aspect ratio of the input sample.
|
225 |
+
Output sample might not have the given width and height, and
|
226 |
+
resize behaviour depends on the parameter 'resize_method'.
|
227 |
+
Defaults to False.
|
228 |
+
ensure_multiple_of (int, optional):
|
229 |
+
Output width and height is constrained to be multiple of this parameter.
|
230 |
+
Defaults to 1.
|
231 |
+
resize_method (str, optional):
|
232 |
+
"lower_bound": Output will be at least as large as the given size.
|
233 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
234 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
235 |
+
Defaults to "lower_bound".
|
236 |
+
"""
|
237 |
+
self.__width = width
|
238 |
+
self.__height = height
|
239 |
+
|
240 |
+
self.__resize_target = resize_target
|
241 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
242 |
+
self.__multiple_of = ensure_multiple_of
|
243 |
+
self.__resize_method = resize_method
|
244 |
+
self.__image_interpolation_method = image_interpolation_method
|
245 |
+
self.__letter_box = letter_box
|
246 |
+
|
247 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
248 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
249 |
+
|
250 |
+
if max_val is not None and y > max_val:
|
251 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
252 |
+
|
253 |
+
if y < min_val:
|
254 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
255 |
+
|
256 |
+
return y
|
257 |
+
|
258 |
+
def get_size(self, width, height):
|
259 |
+
# determine new height and width
|
260 |
+
scale_height = self.__height / height
|
261 |
+
scale_width = self.__width / width
|
262 |
+
|
263 |
+
if self.__keep_aspect_ratio:
|
264 |
+
if self.__resize_method == "lower_bound":
|
265 |
+
# scale such that output size is lower bound
|
266 |
+
if scale_width > scale_height:
|
267 |
+
# fit width
|
268 |
+
scale_height = scale_width
|
269 |
+
else:
|
270 |
+
# fit height
|
271 |
+
scale_width = scale_height
|
272 |
+
elif self.__resize_method == "upper_bound":
|
273 |
+
# scale such that output size is upper bound
|
274 |
+
if scale_width < scale_height:
|
275 |
+
# fit width
|
276 |
+
scale_height = scale_width
|
277 |
+
else:
|
278 |
+
# fit height
|
279 |
+
scale_width = scale_height
|
280 |
+
elif self.__resize_method == "minimal":
|
281 |
+
# scale as little as possible
|
282 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
283 |
+
# fit width
|
284 |
+
scale_height = scale_width
|
285 |
+
else:
|
286 |
+
# fit height
|
287 |
+
scale_width = scale_height
|
288 |
+
else:
|
289 |
+
raise ValueError(
|
290 |
+
f"resize_method {self.__resize_method} not implemented"
|
291 |
+
)
|
292 |
+
|
293 |
+
if self.__resize_method == "lower_bound":
|
294 |
+
new_height = self.constrain_to_multiple_of(
|
295 |
+
scale_height * height, min_val=self.__height
|
296 |
+
)
|
297 |
+
new_width = self.constrain_to_multiple_of(
|
298 |
+
scale_width * width, min_val=self.__width
|
299 |
+
)
|
300 |
+
elif self.__resize_method == "upper_bound":
|
301 |
+
new_height = self.constrain_to_multiple_of(
|
302 |
+
scale_height * height, max_val=self.__height
|
303 |
+
)
|
304 |
+
new_width = self.constrain_to_multiple_of(
|
305 |
+
scale_width * width, max_val=self.__width
|
306 |
+
)
|
307 |
+
elif self.__resize_method == "minimal":
|
308 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
309 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
310 |
+
else:
|
311 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
312 |
+
|
313 |
+
return (new_width, new_height)
|
314 |
+
|
315 |
+
def make_letter_box(self, sample):
|
316 |
+
top = bottom = (self.__height - sample.shape[0]) // 2
|
317 |
+
left = right = (self.__width - sample.shape[1]) // 2
|
318 |
+
sample = cv2.copyMakeBorder(
|
319 |
+
sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0
|
320 |
+
)
|
321 |
+
return sample
|
322 |
+
|
323 |
+
def __call__(self, sample):
|
324 |
+
width, height = self.get_size(
|
325 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
326 |
+
)
|
327 |
+
|
328 |
+
# resize sample
|
329 |
+
sample["image"] = cv2.resize(
|
330 |
+
sample["image"],
|
331 |
+
(width, height),
|
332 |
+
interpolation=self.__image_interpolation_method,
|
333 |
+
)
|
334 |
+
|
335 |
+
if self.__letter_box:
|
336 |
+
sample["image"] = self.make_letter_box(sample["image"])
|
337 |
+
|
338 |
+
if self.__resize_target:
|
339 |
+
if "disparity" in sample:
|
340 |
+
sample["disparity"] = cv2.resize(
|
341 |
+
sample["disparity"],
|
342 |
+
(width, height),
|
343 |
+
interpolation=cv2.INTER_NEAREST,
|
344 |
+
)
|
345 |
+
|
346 |
+
if self.__letter_box:
|
347 |
+
sample["disparity"] = self.make_letter_box(sample["disparity"])
|
348 |
+
|
349 |
+
if "depth" in sample:
|
350 |
+
sample["depth"] = cv2.resize(
|
351 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
352 |
+
)
|
353 |
+
|
354 |
+
if self.__letter_box:
|
355 |
+
sample["depth"] = self.make_letter_box(sample["depth"])
|
356 |
+
|
357 |
+
sample["mask"] = cv2.resize(
|
358 |
+
sample["mask"].astype(np.float32),
|
359 |
+
(width, height),
|
360 |
+
interpolation=cv2.INTER_NEAREST,
|
361 |
+
)
|
362 |
+
|
363 |
+
if self.__letter_box:
|
364 |
+
sample["mask"] = self.make_letter_box(sample["mask"])
|
365 |
+
|
366 |
+
sample["mask"] = sample["mask"].astype(bool)
|
367 |
+
|
368 |
+
return sample
|
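
A small usage sketch of the Resize transform defined above, configured the way test_lseg.py configures it (keep the aspect ratio, snap both sides to a multiple of 32, scale as little as possible); the input shape is arbitrary and the import assumes the repository root is on the Python path:

import numpy as np
import cv2
from utils import Resize  # the class defined above

rs = Resize(
    520,
    520,
    resize_target=False,
    keep_aspect_ratio=True,
    ensure_multiple_of=32,
    resize_method="minimal",
    image_interpolation_method=cv2.INTER_AREA,
)
sample = {"image": np.random.rand(500, 700, 3).astype(np.float32)}
# both output sides are snapped to multiples of 32: a 500x700 input becomes 512x736
print(rs(sample)["image"].shape)  # (512, 736, 3)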